{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "67b1281e-b7de-491a-b3b3-7572eaecc4d7",
      "metadata": {
        "id": "67b1281e-b7de-491a-b3b3-7572eaecc4d7"
      },
      "source": [
        "# CT LAB: Simulating and Reconstructing Tomography Data\n",
        "\n",
        "Authors: L. Calatroni, A. Sebastiani (MaLGa, Unige)\n",
        "\n",
        "Mini-course \"Computational Imaging & Learning\" - MSc in Data Science, University of Padua, Italy."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "8a3b8113",
      "metadata": {},
      "source": [
        "Welcome to the CT computational lab. The core of the **Computed Tomography (CT)** problem lies in recovering a hidden image from a set of projections. While a single X-ray provides only a flattened shadow of the object of interest, CT measures the total X-ray attenuation along many straight paths crossing the object at different angles.\n",
        "\n",
        "The goal is to reconstruct precise internal cross-sections, but the raw data, represented as a sinogram, consist of discrete line integrals corrupted by noise. The objective is therefore to formulate a rigorous forward model of the scanning process based on the discrete Radon transform, and to solve the resulting ill-posed inverse problem to recover the tissue attenuation map. We will use the `deepinv` library to implement optimization algorithms for Maximum Likelihood and Maximum a Posteriori estimation."
      ]
    },
    {
      "cell_type": "markdown",
      "id": "OBY_UrfG2dpc",
      "metadata": {
        "id": "OBY_UrfG2dpc"
      },
      "source": [
        "*First, run the setup cell below to install the necessary libraries and load all the tools.*"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GxY8vq4PMnFX",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "GxY8vq4PMnFX",
        "outputId": "f83128a6-b49a-473a-a28d-6a949a2add3f"
      },
      "outputs": [],
      "source": [
        "!pip install deepinv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "af6063fb-5d4a-42f4-99fc-5ee61ccd8960",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "af6063fb-5d4a-42f4-99fc-5ee61ccd8960",
        "outputId": "4aa8e42e-718e-435e-a2ea-b579c13fbfe8"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Plots\n",
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Importing images and basic operations\n",
        "from skimage.data import shepp_logan_phantom\n",
        "from skimage.color import rgb2gray\n",
        "from skimage.transform import resize\n",
        "\n",
        "import deepinv as dinv\n",
        "import torch\n",
        "\n",
        "# Check if GPU is available\n",
        "device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else \"cpu\"\n",
        "print(f'Device is {device}')\n",
        "\n",
        "# Use a parallel dataloader when a GPU is available, to speed up training.\n",
        "num_workers = 5 if torch.cuda.is_available() else 0\n",
        "\n",
        "dtype = torch.float32\n",
        "circle = False\n",
        "\n",
        "MSE = dinv.loss.metric.MSE()\n",
        "PSNR = dinv.loss.metric.PSNR()\n",
        "SSIM = dinv.loss.metric.SSIM()"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "qwj3FIcC2ojf",
      "metadata": {
        "id": "qwj3FIcC2ojf"
      },
      "source": [
        "# Image loading"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8cbb9c75-72df-41eb-a49d-4dbc6cf5c229",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 355
        },
        "id": "8cbb9c75-72df-41eb-a49d-4dbc6cf5c229",
        "outputId": "fd29c817-5ed8-4456-8f60-0697c0e4ef82"
      },
      "outputs": [],
      "source": [
        "# A benchmark image for medical imaging problems\n",
        "u_SL = shepp_logan_phantom()\n",
        "\n",
        "# Resize the image to a smaller size (nearest-neighbour, i.e. piecewise constant, interpolation)\n",
        "N_SL_small = 128\n",
        "u_SL_small = resize(u_SL, (N_SL_small,N_SL_small), order=0,preserve_range=True,anti_aliasing=True)\n",
        "\n",
        "\n",
        "u_SL = torch.tensor(u_SL).unsqueeze(0).unsqueeze(0).to(dtype).to(device)\n",
        "u_SL_small = torch.tensor(u_SL_small).unsqueeze(0).unsqueeze(0).to(dtype).to(device)\n",
        "\n",
        "\n",
        "dinv.utils.plot([u_SL,u_SL_small], ['Full image', 'Small image'], figsize=(6,4))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "20ad4a6a-d13d-4ece-8217-400210dc1b34",
      "metadata": {
        "id": "20ad4a6a-d13d-4ece-8217-400210dc1b34"
      },
      "source": [
        "# **Radon Transform - Computed Tomography**\n",
        "\n",
        "Collect integrals of an image $u$ along straight lines:\n",
        "\n",
        "$$ \\begin{aligned} \\mathbf{y}(\\boldsymbol{\\theta},\\rho) &= \\int_{L_{\\boldsymbol{\\theta},\\rho}}~u~d\\sigma\n",
        "\\end{aligned} \\qquad \\boldsymbol{\\theta} \\in \\mathbb{R}^2,~ \\| \\boldsymbol{\\theta}\\|=1, \\quad \\rho \\in \\mathbb{R}_+$$\n",
        "\n",
        "where $L_{\\boldsymbol{\\theta},\\rho} = \\left\\{ \\mathbf{x}: \\mathbf{x}\\cdot \\boldsymbol{\\theta} = \\rho \\right\\}$ is the line orthogonal to $\\boldsymbol{\\theta}$ at distance $\\rho$ from the origin; $\\mathbf{y}$ is usually referred to as the *sinogram* of $u$.\n",
        "\n",
        "Subsampling can be motivated by physical limitations in the acquisition procedure and by the goal of reducing the exposure to X-rays:\n",
        "*   **sparse** angles: (uniformly) subsample the angle $\\omega$, where $\\boldsymbol{\\theta} = (\\cos\\omega, \\sin\\omega)$;\n",
        "*   **limited** angles: acquire angles only in a limited wedge $[\\omega_{\\text{min}},\\omega_{\\text{max}}]$."
      ]
    },
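    {
      "cell_type": "markdown",
      "id": "radon-sketch-numpy",
      "metadata": {},
      "source": [
        "The line integrals above can be approximated on a pixel grid in a few lines. The following is a minimal NumPy-only sketch. It assumes a simple nearest-neighbour discretization, which is *not* how `dinv.physics.Tomography` is implemented, and `radon_sketch` is a hypothetical helper. For each angle $\\omega$, the image is sampled at the points $s\\,\\boldsymbol{\\theta} + t\\,\\boldsymbol{\\theta}^\\perp$, with $\\boldsymbol{\\theta}^\\perp = (-\\sin\\omega, \\cos\\omega)$. The samples are then summed over $t$, which approximates the integral over the line $L_{\\boldsymbol{\\theta},s}$ for each detector offset $s$.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def radon_sketch(u, angles_deg):\n",
        "    # Nearest-neighbour approximation of the Radon transform of a square (N, N) image.\n",
        "    # Returns a sinogram of shape (N, n_angles): detector offsets along axis 0, angles along axis 1.\n",
        "    N = u.shape[0]\n",
        "    c = (N - 1) / 2.0  # centre of the pixel grid\n",
        "    s, t = np.meshgrid(np.arange(N) - c, np.arange(N) - c, indexing='ij')\n",
        "    sino = np.zeros((N, len(angles_deg)))\n",
        "    for k, w in enumerate(np.deg2rad(angles_deg)):\n",
        "        # coordinates of the points s*theta + t*theta_perp, with theta = (cos w, sin w)\n",
        "        px = s * np.cos(w) - t * np.sin(w)\n",
        "        py = s * np.sin(w) + t * np.cos(w)\n",
        "        j = np.rint(px + c).astype(int)  # nearest column index\n",
        "        i = np.rint(py + c).astype(int)  # nearest row index\n",
        "        inside = (i >= 0) & (i < N) & (j >= 0) & (j < N)\n",
        "        vals = np.zeros_like(s)\n",
        "        vals[inside] = u[i[inside], j[inside]]\n",
        "        sino[:, k] = vals.sum(axis=1)  # sum over t: one line integral per offset s\n",
        "    return sino\n",
        "\n",
        "# Usage: an off-centre square, 90 angles in [0, 180) degrees\n",
        "u = np.zeros((64, 64))\n",
        "u[20:30, 35:45] = 1.0\n",
        "sino = radon_sketch(u, np.linspace(0, 180, 90, endpoint=False))\n",
        "print(sino.shape)  # (64, 90)\n",
        "```\n",
        "\n",
        "With this convention, each sinogram column holds the column sums of the image at $\\omega = 0$ and its row sums at $\\omega = 90°$. These orientation conventions are assumptions and may differ from the ones used by deepinv."
      ]
    },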
    {
      "cell_type": "markdown",
      "id": "cHrJFhTt294z",
      "metadata": {
        "id": "cHrJFhTt294z"
      },
      "source": [
        "# Defining the Forward model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "yy4Cl0ecSlkG",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "yy4Cl0ecSlkG",
        "outputId": "859fbff2-1bf5-43cd-b2f6-a598dd798593"
      },
      "outputs": [],
      "source": [
        "# Define a forward model with 180 angles in the full angular range [0, 180) degrees (180 itself is excluded: it repeats 0 up to a flip)\n",
        "img_width = N_SL_small\n",
        "N_theta_full = # TODO\n",
        "theta_full = # TODO np.linspace(#, #, #, endpoint=False)\n",
        "theta_full = torch.tensor(theta_full)\n",
        "\n",
        "tomo_full = dinv.physics.Tomography(angles=theta_full, img_width=img_width, circle=circle, device=device)\n",
        "\n",
        "sinogram_full = tomo_full(u_SL_small)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "Os6YzrGZjjSZ",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Os6YzrGZjjSZ",
        "outputId": "7a6eb0d6-7b0b-4e64-f440-a1bee6bec807"
      },
      "outputs": [],
      "source": [
        "# Define a sparse-angle CT forward model with 36 uniformly spaced angles in the full angular range [0, 180) degrees\n",
        "N_theta_sparse = # TODO\n",
        "theta_sparse = # TODO np.linspace(#, #, #, endpoint=False)\n",
        "theta_sparse = torch.tensor(theta_sparse)\n",
        "\n",
        "tomo_sparse = dinv.physics.Tomography(angles=theta_sparse, img_width=img_width, circle=circle, device=device)\n",
        "\n",
        "sinogram_sparse = tomo_sparse(u_SL_small)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "MNh7OC1Ej65C",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "MNh7OC1Ej65C",
        "outputId": "d8200dd7-e070-4c8a-ce24-934dabef5c9d"
      },
      "outputs": [],
      "source": [
        "# define a limited angle CT forward model within the angular range [30, 150]\n",
        "Theta_min = # TODO\n",
        "Theta_max = # TODO\n",
        "N_theta_lim = np.round((Theta_max-Theta_min)/180*N_theta_full).astype('int')\n",
        "theta_limited = # TODO\n",
        "theta_limited = torch.tensor(theta_limited)\n",
        "\n",
        "tomo_limited = dinv.physics.Tomography(angles=theta_limited, img_width=img_width, circle=circle, device=device)\n",
        "\n",
        "sinogram_limited = tomo_limited(u_SL_small)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "kIm8wmZegMKs",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 355
        },
        "id": "kIm8wmZegMKs",
        "outputId": "54ae6593-a22c-4b34-cf63-eaab13ccca86"
      },
      "outputs": [],
      "source": [
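        "# For display, embed the sparse- and limited-angle data into zero-filled full-size sinograms\n",
        "# (column index = angle in degrees, since theta_full has a 1-degree spacing starting at 0)\n",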
        "fill_sinogram_sparse = torch.zeros_like(sinogram_full)\n",
        "fill_sinogram_sparse[..., theta_sparse.int()]=# TODO\n",
        "\n",
        "fill_sinogram_limited = torch.zeros_like(sinogram_full)\n",
        "fill_sinogram_limited[..., theta_limited.int()]=# TODO\n",
        "\n",
        "img_list = [u_SL_small, sinogram_full, fill_sinogram_sparse, fill_sinogram_limited]\n",
        "title_list = ['Image', 'Full sinogram', 'Sparse-angle sinogram', 'Limited-angle sinogram']\n",
        "\n",
        "dinv.utils.plot(img_list, title_list, figsize=(12,4))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "id": "KEtoontsn3_c",
      "metadata": {
        "id": "KEtoontsn3_c"
      },
      "outputs": [],
      "source": [
        "# Gaussian noise whose standard deviation is relative to the signal:\n",
        "# y_noisy = y + sigma * max|y| * n,  n ~ N(0, I)\n",
        "class GaussianNoiseCT(dinv.physics.GaussianNoise):\n",
        "    def __init__(\n",
        "        self,\n",
        "        sigma: float | torch.Tensor = 0.1,\n",
        "        rng: torch.Generator | None = None,\n",
        "    ):\n",
        "        super().__init__(sigma=sigma, rng=rng)\n",
        "\n",
        "    def forward(self, x, sigma=None, seed=None, **kwargs):\n",
        "        self.update_parameters(sigma=sigma, **kwargs)\n",
        "        self.to(x.device)\n",
        "        # scale the standard Gaussian noise by the peak absolute value of x and by sigma\n",
        "        return (\n",
        "            x\n",
        "            + torch.max(torch.abs(x)) * self.randn_like(x, seed=seed)\n",
        "            * self.sigma[(...,) + (None,) * (x.dim() - 1)]\n",
        "        )"
      ]
    },
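    {
      "cell_type": "markdown",
      "id": "relative-noise-sketch",
      "metadata": {},
      "source": [
        "The class above adds noise with standard deviation `sigma * max|x|`. In other words, `sigma` is a noise level *relative* to the peak of the sinogram rather than an absolute one, so values such as `1e-3` or `1e-2` can be read independently of the sinogram's scale. Below is a minimal NumPy sketch of the same noise model, independent of deepinv. The function name `add_relative_gaussian_noise` is hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def add_relative_gaussian_noise(y, sigma, seed=0):\n",
        "    # y_noisy = y + sigma * max|y| * n, with n ~ N(0, I) drawn entrywise\n",
        "    rng = np.random.default_rng(seed)\n",
        "    return y + sigma * np.max(np.abs(y)) * rng.standard_normal(y.shape)\n",
        "\n",
        "# Usage: a stand-in for a flattened sinogram with max|y| = 4\n",
        "y = np.linspace(-2.0, 4.0, 200_000)\n",
        "y_noisy = add_relative_gaussian_noise(y, sigma=1e-2)\n",
        "print(np.std(y_noisy - y))  # close to sigma * max|y| = 0.04\n",
        "```"
      ]
    },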
    {
      "cell_type": "markdown",
      "id": "f1cc4f41-3fd0-41e4-a18a-023508c43171",
      "metadata": {
        "id": "f1cc4f41-3fd0-41e4-a18a-023508c43171"
      },
      "source": [
        "# **Naive inversion**\n",
        "\n",
        "An inverse problem can be represented as recovering $u$ from\n",
        "$$ y = \\operatorname{noise}(A(u))$$\n",
        "If $A$ is linear and the noise is additive, the discrete formulation reduces to\n",
        "$$ y = \\mathbf{A} u + \\epsilon, \\quad u \\in \\mathbb{R}^n,~ \\epsilon \\in \\mathbb{R}^m,~ \\mathbf{A} \\in \\mathbb{R}^{m \\times n} $$\n",
        "\n",
        "A naive idea to solve it is to invert $\\mathbf{A}$ directly:\n",
        "$$ u_{\\text{naive}} = \\mathbf{A}^{-1} y$$\n",
        "\n",
        "In practice $\\mathbf{A}$ is generally not square and is ill-conditioned, so $\\mathbf{A}^{-1}$ either does not exist or strongly amplifies the noise $\\epsilon$. For the Radon transform, we use the Filtered Back-Projection (FBP) to \"invert\" the measurements."
      ]
    },
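    {
      "cell_type": "markdown",
      "id": "naive-inversion-sketch",
      "metadata": {},
      "source": [
        "The failure of the naive inverse can be illustrated on a toy problem. The minimal NumPy sketch below does not use the CT operator. Instead, it builds a square $50\\times 50$ matrix with random orthogonal singular vectors and singular values decaying from $1$ to $10^{-10}$, an exaggerated stand-in for the singular-value decay of ill-posed operators. Each component of the noise along a singular direction is divided by the corresponding singular value, so a tiny data error turns into a large reconstruction error.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "n = 50\n",
        "# A = Q_left diag(s) Q_right^T with random orthogonal factors\n",
        "Q_left, _ = np.linalg.qr(rng.standard_normal((n, n)))\n",
        "Q_right, _ = np.linalg.qr(rng.standard_normal((n, n)))\n",
        "s = np.logspace(0, -10, n)  # singular values from 1 down to 1e-10\n",
        "A = Q_left @ np.diag(s) @ Q_right.T\n",
        "\n",
        "u_true = np.ones(n)\n",
        "eps = 1e-6 * rng.standard_normal(n)  # small additive noise\n",
        "y = A @ u_true + eps\n",
        "\n",
        "u_naive = np.linalg.solve(A, y)  # u_naive = A^{-1} y\n",
        "rel_noise = np.linalg.norm(eps) / np.linalg.norm(A @ u_true)\n",
        "rel_err = np.linalg.norm(u_naive - u_true) / np.linalg.norm(u_true)\n",
        "print(f'relative data error: {rel_noise:.1e}, relative reconstruction error: {rel_err:.1e}')\n",
        "```"
      ]
    },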
    {
      "cell_type": "markdown",
      "id": "MKdhx8mE31st",
      "metadata": {
        "id": "MKdhx8mE31st"
      },
      "source": [
        "**TASK 1: Experiment with the Forward Model**\n",
        "\n",
        "Change the variables in the cells below and re-run them to see the effects!\n",
        "* **Change the physics:** Modify `tomo_exp` (the forward operator) to see how the size of the sinogram changes.\n",
        "* **Change the noise:** Adjust the noise levels `nl1` and `nl2`. How does the image look when the noise level increases?\n",
        "* **Observe the reconstructions:** Compare the reconstructions obtained with the FBP under different acquisition geometries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "010Jpn9eq6K6",
      "metadata": {
        "id": "010Jpn9eq6K6"
      },
      "outputs": [],
      "source": [
        "# Select the forward operator for the experiments below: tomo_full, tomo_sparse or tomo_limited\n",
        "tomo_exp = # TODO"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "T182LrbXhDpl",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 839
        },
        "id": "T182LrbXhDpl",
        "outputId": "ab74b158-1626-4f1d-975e-b3fbd5e7f5bc"
      },
      "outputs": [],
      "source": [
        "sinogram = tomo_exp(u_SL_small)\n",
        "u_rec = tomo_exp.fbp(sinogram)\n",
        "\n",
        "# add noise to the sinogram with noise level 1e-3\n",
        "nl1 = # TODO\n",
        "noise_ph = # TODO\n",
        "sinogram_noisy = noise_ph(sinogram)\n",
        "# compute the FBP of the noisy sinogram\n",
        "u_rec_noisy = # TODO\n",
        "\n",
        "# add noise to the sinogram with noise level 1e-2\n",
        "nl2 = # TODO\n",
        "noise_ph2 = # TODO\n",
        "sinogram_noisy2 = noise_ph2(sinogram)\n",
        "# compute the FBP of the noisy sinogram\n",
        "u_rec_noisy2 = # TODO\n",
        "\n",
        "titles_sino = ['Noiseless', f'nl = {nl1}', f'nl = {nl2}']\n",
        "titles_rec = ['Noiseless recon', f'Noisy recon nl = {nl1}', f'Noisy recon nl = {nl2}']\n",
        "\n",
        "dinv.utils.plot([sinogram, sinogram_noisy, sinogram_noisy2], titles_sino, figsize=(12,4))\n",
        "dinv.utils.plot([u_rec, u_rec_noisy, u_rec_noisy2], titles_rec, figsize=(12,4))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "QiRPnxb0qavI",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QiRPnxb0qavI",
        "outputId": "e45bc9f9-bacc-42f8-80d7-7ca4ad3b5d21"
      },
      "outputs": [],
      "source": [
        "# Compute the quality metrics for the reconstructed images\n",
        "print(f'Reconstruction (noiseless sinogram)')\n",
        "print(f' MSE={MSE(u_rec, u_SL_small).item():.4f}\\n PSNR={PSNR(u_rec, u_SL_small).item():.2f}\\n SSIM={SSIM(u_rec, u_SL_small).item():.2f}\\n')\n",
        "\n",
        "print(f'Reconstruction (sinogram noise lev {nl1})')\n",
        "print(f' MSE={MSE(u_rec_noisy, u_SL_small).item():.4f}\\n PSNR={PSNR(u_rec_noisy, u_SL_small).item():.2f}\\n SSIM={SSIM(u_rec_noisy, u_SL_small).item():.2f}\\n')\n",
        "\n",
        "print(f'Reconstruction (sinogram noise lev {nl2})')\n",
        "print(f' MSE={MSE(u_rec_noisy2, u_SL_small).item():.4f}\\n PSNR={PSNR(u_rec_noisy2, u_SL_small).item():.2f}\\n SSIM={SSIM(u_rec_noisy2, u_SL_small).item():.2f}\\n')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "pmADQnFJv7h6",
      "metadata": {
        "id": "pmADQnFJv7h6"
      },
      "outputs": [],
      "source": [
        "# Select the corrupted sinogram used in the following experiments\n",
        "y = sinogram_noisy2"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "158fa19d-1296-4401-b517-07c5cb9d1785",
      "metadata": {
        "id": "158fa19d-1296-4401-b517-07c5cb9d1785"
      },
      "source": [
        "# Unregularized Reconstruction\n",
        "Ideally, we would like to solve the system $\\mathbf{A} u = y$. However, assembling the matrix $\\mathbf{A}$ is very expensive in terms of memory, and we can avoid it altogether.\n",
        "\n",
        "Instead, we can use iterative solvers for $\\mathbf{A} u = y$, e.g. minimize $\\frac{1}{2}\\| \\mathbf{A} u- y\\|^2$ via the **gradient method** with step size $\\tau > 0$:\n",
        "\n",
        "$$\n",
        "\\left\\{\n",
        "\\begin{aligned}\n",
        "u^{(0)} & \\text{ given} \\\\\n",
        "u^{(k+1)} &= u^{(k)} - \\tau\\, \\mathbf{A}^T(\\mathbf{A}u^{(k)}-y)\n",
        "\\end{aligned}\n",
        "\\right.\n",
        "$$\n",
        "\n",
        "This only requires the ability to apply $\\mathbf{A}$ and its adjoint $\\mathbf{A}^T$. Both can be applied matrix-free, through functions (operators) that compute the products $\\mathbf{A}u$ and $\\mathbf{A}^T r$ without ever storing $\\mathbf{A}$. The iteration decreases the cost for $0 < \\tau < 2/\\|\\mathbf{A}\\|^2$."
      ]
    },
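    {
      "cell_type": "markdown",
      "id": "matrix-free-gd-sketch",
      "metadata": {},
      "source": [
        "The gradient iteration can be sketched matrix-free in NumPy. This is a minimal sketch under stated assumptions. A small random matrix `A_dense` stands in for the Radon transform and is only touched inside two functions, `A_op` and `AT_op`, which are hypothetical names playing the roles of $\\mathbf{A}$ and $\\mathbf{A}^T$. The step size is $\\tau = 1/\\|\\mathbf{A}\\|^2$, with $\\|\\mathbf{A}\\|^2$ estimated by power iteration on $\\mathbf{A}^T\\mathbf{A}$; this estimate also needs only operator applications. The dense least-squares solution is computed at the end for comparison only.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "m, n = 80, 40\n",
        "A_dense = rng.standard_normal((m, n))  # stand-in matrix, accessed only inside A_op and AT_op\n",
        "\n",
        "def A_op(u):  # forward operator: u -> A u\n",
        "    return A_dense @ u\n",
        "\n",
        "def AT_op(r):  # adjoint operator: r -> A^T r\n",
        "    return A_dense.T @ r\n",
        "\n",
        "def power_iteration(n_iter=100):\n",
        "    # Estimate ||A||^2, the largest eigenvalue of A^T A, using only A_op and AT_op\n",
        "    v = rng.standard_normal(n)\n",
        "    for _ in range(n_iter):\n",
        "        v = AT_op(A_op(v))\n",
        "        v /= np.linalg.norm(v)\n",
        "    return np.dot(v, AT_op(A_op(v)))\n",
        "\n",
        "u_true = rng.standard_normal(n)\n",
        "y = A_op(u_true) + 0.01 * rng.standard_normal(m)\n",
        "\n",
        "tau = 1.0 / power_iteration()  # step size tau = 1 / ||A||^2\n",
        "u = np.zeros(n)  # u^(0)\n",
        "costs = []\n",
        "for k in range(1000):\n",
        "    r = A_op(u) - y\n",
        "    costs.append(0.5 * np.dot(r, r))\n",
        "    u = u - tau * AT_op(r)  # u^(k+1) = u^(k) - tau * A^T (A u^(k) - y)\n",
        "\n",
        "u_ls = np.linalg.lstsq(A_dense, y, rcond=None)[0]  # dense least-squares solution, for comparison only\n",
        "print(f'cost: {costs[0]:.3e} -> {costs[-1]:.3e}')\n",
        "print('relative distance to least squares:', np.linalg.norm(u - u_ls) / np.linalg.norm(u_ls))\n",
        "```\n",
        "\n",
        "Power iteration approaches $\\|\\mathbf{A}\\|^2$ from below. The resulting $\\tau$ is therefore slightly larger than $1/\\|\\mathbf{A}\\|^2$, but it stays well inside the stable range $(0, 2/\\|\\mathbf{A}\\|^2)$."
      ]
    },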
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "GtZIqKx_n8mA",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 938
        },
        "id": "GtZIqKx_n8mA",
        "outputId": "c721e402-0157-4ce8-f78b-4f7d51502aed"
      },
      "outputs": [],
      "source": [
        "# Define Fidelity\n",
        "fidelity = dinv.optim.data_fidelity.L2()\n",
        "\n",
        "# Define Prior\n",
        "#prior = dinv.optim.prior.Zero()\n",
        "prior = dinv.optim.prior.ZeroPrior()\n",
        "\n",
        "# Define optimizer (check https://deepinv.github.io/deepinv/api/stubs/deepinv.optim.optim_builder.html)\n",
        "opt = dinv.optim.optim_builder() #TODO\n",
        "# Run\n",
        "u_hat, metrics = opt(y, tomo_exp, compute_metrics=True)\n",
        "\n",
        "print(f'Reconstruction experiment')\n",
        "print(f' MSE={MSE(u_hat, u_SL_small).item():.4f}\\n PSNR={PSNR(u_hat, u_SL_small).item():.2f}\\n SSIM={SSIM(u_hat, u_SL_small).item():.2f}\\n')\n",
        "\n",
        "plt.plot(metrics['cost'][0])\n",
        "dinv.utils.plot([u_hat, u_SL_small], ['Recons itr', 'GT'], figsize=(12,4))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "df0f2112-5362-449d-b375-5dbe435191f7",
      "metadata": {
        "id": "df0f2112-5362-449d-b375-5dbe435191f7"
      },
      "source": [
        "# Tikhonov regularization\n",
        "Penalizes solutions with large norms:\n",
        "\n",
        "$$ u_{\\alpha}^* = \\arg\\min_{u \\in \\mathbb{R}^n} \\left\\{ \\frac{1}{2}\\| A u - y \\|^2 + \\frac{\\alpha}{2} \\| u\\|^2\\right\\}$$\n",
        "\n",
        "(the first term of the sum measures the data fidelity and might be chosen differently according to the noise model).\n",
        "\n",
        "The solution can also be found via the first-order optimality conditions:\n",
        "\n",
        "$$  A^T(A u_{\\alpha}^* -y) + \\alpha u_{\\alpha}^* = 0 \\quad \\Rightarrow \\quad (A^T A + \\alpha \\operatorname{Id})u_{\\alpha}^* = A^T y $$\n",
        "\n",
        "Thus $u_\\alpha^*$ can be found by solving a linear system whose matrix is symmetric and positive definite for every $\\alpha > 0$."
      ]
    },
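    {
      "cell_type": "markdown",
      "id": "tikhonov-cg-sketch",
      "metadata": {},
      "source": [
        "Since $\\mathbf{A}^T\\mathbf{A} + \\alpha\\operatorname{Id}$ is symmetric positive definite, this linear system can be solved matrix-free with the conjugate gradient (CG) method. Like the gradient method, CG needs only products with $\\mathbf{A}$ and $\\mathbf{A}^T$. Below is a minimal NumPy sketch. It assumes a small random stand-in matrix and uses a hypothetical helper named `cg`; the result is checked against a dense solve.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def cg(apply_H, b, n_iter=200, tol=1e-10):\n",
        "    # Conjugate gradient for H x = b, with H symmetric positive definite\n",
        "    # and accessed only through apply_H(x) = H x.\n",
        "    x = np.zeros_like(b)\n",
        "    r = b - apply_H(x)  # residual\n",
        "    p = r.copy()  # search direction\n",
        "    rs = np.dot(r, r)\n",
        "    for _ in range(n_iter):\n",
        "        Hp = apply_H(p)\n",
        "        step = rs / np.dot(p, Hp)\n",
        "        x = x + step * p\n",
        "        r = r - step * Hp\n",
        "        rs_new = np.dot(r, r)\n",
        "        if np.sqrt(rs_new) < tol * np.linalg.norm(b):\n",
        "            break\n",
        "        p = r + (rs_new / rs) * p\n",
        "        rs = rs_new\n",
        "    return x\n",
        "\n",
        "# Usage: Tikhonov with an underdetermined stand-in matrix (A^T A alone is singular)\n",
        "rng = np.random.default_rng(1)\n",
        "A = rng.standard_normal((60, 100))\n",
        "y = rng.standard_normal(60)\n",
        "alpha = 0.1\n",
        "u_cg = cg(lambda u: A.T @ (A @ u) + alpha * u, A.T @ y)  # never forms A^T A\n",
        "u_direct = np.linalg.solve(A.T @ A + alpha * np.eye(100), A.T @ y)  # dense check\n",
        "print(np.linalg.norm(u_cg - u_direct) / np.linalg.norm(u_direct))\n",
        "```\n",
        "\n",
        "Here $\\mathbf{A}^T\\mathbf{A}$ is singular because the stand-in matrix has more columns than rows. Nevertheless, $\\mathbf{A}^T\\mathbf{A} + \\alpha\\operatorname{Id}$ is positive definite for any $\\alpha > 0$, which is what makes CG applicable."
      ]
    },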
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "na5RuzEbZxSm",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 938
        },
        "id": "na5RuzEbZxSm",
        "outputId": "1a0feb21-fb1f-4794-d49d-3b28447523a5"
      },
      "outputs": [],
      "source": [
        "# Define Fidelity\n",
        "fidelity = dinv.optim.data_fidelity.L2()\n",
        "\n",
        "# Define Prior\n",
        "prior = dinv.optim.Tikhonov()\n",
        "\n",
        "# Define optimizer\n",
        "opt = dinv.optim.optim_builder() #TODO\n",
        "## Run\n",
        "u_hat, metrics = opt(y, tomo_exp, compute_metrics=True)\n",
        "\n",
        "print(f'Reconstruction experiment')\n",
        "print(f' MSE={MSE(u_hat, u_SL_small).item():.4f}\\n PSNR={PSNR(u_hat, u_SL_small).item():.2f}\\n SSIM={SSIM(u_hat, u_SL_small).item():.2f}\\n')\n",
        "\n",
        "plt.plot(metrics['cost'][0])\n",
        "dinv.utils.plot([u_hat, u_SL_small], ['Recons itr', 'GT'], figsize=(12,4))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "cd0f8cae-5ca8-47a9-a654-c308d386782c",
      "metadata": {
        "id": "cd0f8cae-5ca8-47a9-a654-c308d386782c"
      },
      "source": [
        "# Sparsity-promoting regularization (LASSO)\n",
        "Penalizes solutions with large $1$-norms. This can be seen as the convex relaxation of the $0$-'norm' penalization, hence it promotes solutions with **few pixels different from 0**:\n",
        "\n",
        "$$ u_{\\alpha}^* = \\arg\\min_{u \\in \\mathbb{R}^n} \\left\\{ \\frac{1}{2}\\| Au - y \\|^2 + \\alpha \\| u\\|_1 \\right\\}$$\n",
        "\n",
        "Approximate $u_\\alpha^*$ via the **Proximal-Gradient Descent** (**PGD**) method: the proximal operator of $\\alpha \\| \\cdot \\|_1$ is **soft-thresholding**, $S_\\alpha(u) = \\operatorname{sign}(u) (|u|-\\alpha)^+$, applied entrywise. This version of PGD is known as ISTA (Iterative Soft-Thresholding Algorithm). With step size $\\tau > 0$, it reads:\n",
        "\n",
        "$$\n",
        "\\left\\{\n",
        "\\begin{aligned}\n",
        "u^{0}   & \\quad \\text{given} \\\\\n",
        "u^{k+1} &= S_{\\tau \\alpha} \\big(u^{k} - \\tau A^T(A u^{k}-y)\\big)\n",
        "\\end{aligned}\n",
        "\\right.\n",
        "$$"
      ]
    },
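    {
      "cell_type": "markdown",
      "id": "ista-sketch",
      "metadata": {},
      "source": [
        "Below is a minimal NumPy sketch of the ISTA iteration. It assumes a random stand-in matrix and a synthetic sparse signal, and `soft_threshold` and `ista` are hypothetical helper names. The step size is $\\tau = 1/\\|\\mathbf{A}\\|^2$, computed exactly with `np.linalg.norm(A, 2)` because the toy matrix is small. Note that the threshold applied at each step is $\\tau\\alpha$, not $\\alpha$.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def soft_threshold(u, thr):\n",
        "    # S_thr(u) = sign(u) * max(|u| - thr, 0), applied entrywise\n",
        "    return np.sign(u) * np.maximum(np.abs(u) - thr, 0.0)\n",
        "\n",
        "def ista(A, y, alpha, n_iter=2000):\n",
        "    tau = 1.0 / np.linalg.norm(A, 2) ** 2  # step size 1 / ||A||^2\n",
        "    u = np.zeros(A.shape[1])  # u^0\n",
        "    for _ in range(n_iter):\n",
        "        grad = A.T @ (A @ u - y)  # gradient of 0.5 * ||A u - y||^2\n",
        "        u = soft_threshold(u - tau * grad, tau * alpha)\n",
        "    return u\n",
        "\n",
        "# Usage: recover a 3-sparse vector from 50 noisy random measurements\n",
        "rng = np.random.default_rng(0)\n",
        "m, n = 50, 100\n",
        "A = rng.standard_normal((m, n)) / np.sqrt(m)\n",
        "u_true = np.zeros(n)\n",
        "u_true[[5, 37, 80]] = [1.5, -2.0, 1.0]\n",
        "y = A @ u_true + 0.01 * rng.standard_normal(m)\n",
        "u_hat = ista(A, y, alpha=0.05)\n",
        "print(np.flatnonzero(np.abs(u_hat) > 0.1))  # indices of the clearly nonzero entries\n",
        "```"
      ]
    },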
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "RfmMYrYybrFh",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 938
        },
        "id": "RfmMYrYybrFh",
        "outputId": "deb663bd-9e9c-4241-c43c-b3e893754d0a"
      },
      "outputs": [],
      "source": [
        "# Define Fidelity\n",
        "fidelity = dinv.optim.data_fidelity.L2()\n",
        "\n",
        "# Define Prior\n",
        "prior = dinv.optim.L1Prior()\n",
        "\n",
        "# Define optimizer\n",
        "opt = dinv.optim.optim_builder() #TODO\n",
        "## Run\n",
        "u_hat, metrics = opt(y, tomo_exp, compute_metrics=True)\n",
        "\n",
        "print(f'Reconstruction experiment')\n",
        "print(f' MSE={MSE(u_hat, u_SL_small).item():.4f}\\n PSNR={PSNR(u_hat, u_SL_small).item():.2f}\\n SSIM={SSIM(u_hat, u_SL_small).item():.2f}\\n')\n",
        "\n",
        "plt.plot(metrics['cost'][0])\n",
        "dinv.utils.plot([u_hat, u_SL_small], ['Recons itr', 'GT'], figsize=(12,4))"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "fAHflUNa7krv",
      "metadata": {
        "id": "fAHflUNa7krv"
      },
      "source": [
        "**TASK 2: Hyperparameter search**\n",
        "\n",
        "Run the algorithms above. Play with the `stepsize`, `lambda` and `max_iter` hyperparameters to see how they affect the reconstruction. Does running it for more iterations always make the image better?\n",
        "\n",
        "_Use the code below to see how MSE, PSNR and SSIM change with the parameters._"
      ]
    },
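    {
      "cell_type": "markdown",
      "id": "stepsize-threshold-sketch",
      "metadata": {},
      "source": [
        "For the plain gradient method on the least-squares cost $\\frac{1}{2}\\|\\mathbf{A}u - y\\|^2$, the effect of the step size can be predicted. The cost decreases for $0 < \\tau < 2/\\|\\mathbf{A}\\|^2$ and, in general, the iterates diverge for $\\tau > 2/\\|\\mathbf{A}\\|^2$. For an unnormalized discrete Radon transform, $\\|\\mathbf{A}\\|^2$ typically grows with the image width and the number of angles, so a step size of order 1 may already be too large. Below is a minimal NumPy sketch of this threshold. It uses a random stand-in matrix and a hypothetical helper `run_gd`, and it assumes that deepinv's `stepsize` plays the role of $\\tau$.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "A = rng.standard_normal((80, 40))  # stand-in for the CT matrix\n",
        "y = rng.standard_normal(80)\n",
        "L = np.linalg.norm(A, 2) ** 2  # ||A||^2, the largest eigenvalue of A^T A\n",
        "\n",
        "def run_gd(tau, n_iter=200):\n",
        "    # Gradient descent on 0.5 * ||A u - y||^2 from u = 0; returns the final cost\n",
        "    u = np.zeros(A.shape[1])\n",
        "    for _ in range(n_iter):\n",
        "        u = u - tau * (A.T @ (A @ u - y))\n",
        "    return 0.5 * np.sum((A @ u - y) ** 2)\n",
        "\n",
        "cost0 = 0.5 * np.sum(y ** 2)  # cost at u = 0\n",
        "for factor in [0.5, 1.0, 1.9, 2.1]:\n",
        "    print(f'tau = {factor} / ||A||^2: final cost {run_gd(factor / L):.3e} (initial {cost0:.3e})')\n",
        "```"
      ]
    },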
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "TAwDKj9Ws2eQ",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 440
        },
        "id": "TAwDKj9Ws2eQ",
        "outputId": "7b036825-c576-4f4d-d365-4847482e58c4"
      },
      "outputs": [],
      "source": [
        "MSE_list = []\n",
        "PSNR_list = []\n",
        "SSIM_list = []\n",
        "par_list = np.linspace(1e-3, 1, num=5)\n",
        "\n",
        "for par in par_list:\n",
        "    opt = dinv.optim.optim_builder(iteration=\"GD\", prior=prior, data_fidelity=fidelity, \\\n",
        "                                    max_iter=100, crit_conv='residual', thres_conv=1e-4, early_stop=True, \\\n",
        "                                    params_algo={\"stepsize\": par}\n",
        "                                  )\n",
        "    u_hat = opt(y, tomo_exp)\n",
        "\n",
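        "    # .item() turns each metric into a Python float, so it can be plotted even when computed on the GPU\n",
        "    MSE_list.append(MSE(u_hat, u_SL_small).item())\n",
        "    PSNR_list.append(PSNR(u_hat, u_SL_small).item())\n",
        "    SSIM_list.append(SSIM(u_hat, u_SL_small).item())\n",
        "\n",
        "fig, axs = plt.subplots(1, 3, figsize=(12, 3))\n",
        "for ax, values, name in zip(axs, [MSE_list, PSNR_list, SSIM_list], ['MSE', 'PSNR', 'SSIM']):\n",
        "    ax.plot(par_list, values, 'o-')\n",
        "    ax.set_xlabel('stepsize')\n",
        "    ax.set_title(name)\n",
        "plt.tight_layout()\n",
        "plt.show()"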
      ]
    },
    {
      "cell_type": "markdown",
      "id": "zcwZj6uEZgeg",
      "metadata": {
        "id": "zcwZj6uEZgeg"
      },
      "source": [
        "**Optional TASK: Constructing A**\n",
        "- Assemble the matrix $\\mathbf{A}$ associated with the Radon transform (use a small image width `N` and a sparse geometry at first).\n",
        "- Compute its singular values and observe their decay."
      ]
    },
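    {
      "cell_type": "markdown",
      "id": "assemble-matrix-sketch",
      "metadata": {},
      "source": [
        "Assembly works for any linear operator: applying it to the $j$-th canonical basis image $e_j$ returns the $j$-th column of $\\mathbf{A}$. Below is a minimal NumPy sketch. Instead of the deepinv `Tomography` operator, it assumes a toy operator that keeps only the 0° and 90° projections (column and row sums) of an $N\\times N$ image; the names `two_angle_projection` and `assemble_matrix` are hypothetical. With only two angles the matrix has rank $2N-1$, since both projections share the total sum of the image. Its null space therefore has dimension $(N-1)^2$, an extreme case of sparse-angle ill-posedness.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "N = 8\n",
        "\n",
        "def two_angle_projection(u):\n",
        "    # Toy forward operator: 0-degree and 90-degree projections of an N x N image\n",
        "    return np.concatenate([u.sum(axis=0), u.sum(axis=1)])\n",
        "\n",
        "def assemble_matrix(op, N):\n",
        "    # Column j of A is op(e_j), where e_j is the j-th canonical basis image (row-major order)\n",
        "    M = op(np.zeros((N, N))).size\n",
        "    A = np.zeros((M, N * N))\n",
        "    for j in range(N * N):\n",
        "        e_j = np.zeros(N * N)\n",
        "        e_j[j] = 1.0\n",
        "        A[:, j] = op(e_j.reshape(N, N)).ravel()\n",
        "    return A\n",
        "\n",
        "A = assemble_matrix(two_angle_projection, N)\n",
        "S = np.linalg.svd(A, compute_uv=False)  # singular values in decreasing order\n",
        "print(A.shape, np.linalg.matrix_rank(A))  # (16, 64) 15\n",
        "```"
      ]
    },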
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "qNauKKaTF5yH",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qNauKKaTF5yH",
        "outputId": "e6fb3004-de24-40a6-e545-ae3892970906"
      },
      "outputs": [],
      "source": [
        "N = 30\n",
        "N_angles = 30\n",
        "theta = np.linspace(0, 180, N_angles, endpoint=False)\n",
        "theta = torch.tensor(theta)\n",
        "tomoSVD = dinv.physics.Tomography(angles=theta, img_width=N, circle=circle, device=device)\n",
        "\n",
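        "# Apply the operator to a dummy image to get the number of measurements M\n",
        "temp = tomoSVD(torch.zeros((1, 1, N, N), dtype=dtype, device=device))\n",
        "M = torch.numel(temp)"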
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "qDrDu-kGHITb",
      "metadata": {
        "id": "qDrDu-kGHITb"
      },
      "outputs": [],
      "source": [
        "A_mat = #TODO\n",
        "\n",
        "for i in range(...):\n",
        "  e_i = #TODO\n",
        "  e_i[i] = 1\n",
        "  e_i = e_i.view(1,1,N,N)\n",
        "  A_mat[:, i] =  torch.reshape(...,(M,))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3gPu0c3iLQH1",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 428
        },
        "id": "3gPu0c3iLQH1",
        "outputId": "e003db43-4fed-4a57-9edf-d5ef33168965"
      },
      "outputs": [],
      "source": [
        "dinv.utils.plot(A_mat.unsqueeze(0), ['A_mat'], figsize=(6,4))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "l6tbKtEXIVFO",
      "metadata": {
        "id": "l6tbKtEXIVFO"
      },
      "outputs": [],
      "source": [
        "S = torch.linalg.svdvals(A_mat)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "XomXRpcSMKbw",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 440
        },
        "id": "XomXRpcSMKbw",
        "outputId": "e1245cb5-d677-4a81-e793-912889157f6c"
      },
      "outputs": [],
      "source": [
        "plt.semilogy(S.cpu().numpy())  # move to CPU first, in case A_mat lives on the GPU\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b1415874-bc81-4a62-a15d-645fe1925998",
      "metadata": {
        "id": "b1415874-bc81-4a62-a15d-645fe1925998"
      },
      "outputs": [],
      "source": [
        "# TASK: observe the decay of the singular values in case of Radon, limited-angle Radon, sparse-angle Radon"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "67b1281e-b7de-491a-b3b3-7572eaecc4d7"
      ],
      "provenance": []
    },
    "kernelspec": {
      "display_name": "base",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.12.7"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
