|
1 | 1 | { |
2 | 2 | "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Fast pixel uncertainty quantification (UQ) example with the `QuantifAI` model\n", |
| 8 | + "\n", |
| 9 | + "In this notebook we:\n", |
| 10 | + "- set hyperparameters,\n", |
| 11 | + "- prepare the synthetic observations,\n", |
| 12 | + "- define the model, likelihood and prior,\n", |
| 13 | + "- estimate the MAP reconstruction through a convex optimisation algorithm,\n", |
| 14 | + "- compute the fast pixel UQ maps,\n", |
| 15 | + "- compare the predicted error maps and the true error maps.\n", |
| 16 | + "\n" |
| 17 | + ] |
| 18 | + }, |
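The new introductory cell above summarises the pipeline end to end. For orientation, below is a minimal, self-contained NumPy sketch of the HPD-threshold idea behind the fast pixel UQ step; the toy potential, the stand-in data, and the helper names (`potential`, `pixel_upper_bound`) are illustrative assumptions, not the QuantifAI API.

```python
# Hedged sketch: per-pixel credible bounds from the HPD-region threshold.
# Everything here is a toy stand-in for the notebook's actual model.
import numpy as np

rng = np.random.default_rng(0)

N = 64 * 64                            # number of pixels
x_map = rng.normal(size=N)             # stand-in for the MAP reconstruction
y = x_map + 0.1 * rng.normal(size=N)   # stand-in for the observations
sigma = 0.1                            # assumed noise level


def potential(x):
    # Toy negative log-posterior: Gaussian likelihood plus a quadratic prior.
    return 0.5 * np.sum((x - y) ** 2) / sigma**2 + 0.5 * np.sum(x**2)


alpha = 0.01                           # (1 - alpha) credible level
tau_alpha = np.sqrt(16 * np.log(3 / alpha))
gamma_alpha = potential(x_map) + tau_alpha * np.sqrt(N) + N


def pixel_upper_bound(idx, hi=20.0, iters=50, tol=1e-4):
    # Bisect on a single-pixel perturbation: find the largest delta that
    # keeps the image inside the HPD region {x: potential(x) <= gamma_alpha}.
    lo = 0.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        x_pert = x_map.copy()
        x_pert[idx] += mid
        if potential(x_pert) <= gamma_alpha:
            lo = mid
        else:
            hi = mid
        if hi - lo < tol:
            break
    return lo


print(f"HPD threshold gamma_alpha: {gamma_alpha:.2f}")
print(f"Upper credible bound for pixel 0: {pixel_upper_bound(0):.4f}")
```

In the notebook the toy potential is replaced by the data-fidelity term plus the learned convex-ridge-regulariser prior (`likelihood.fun` and `CRR_model.cost` in the cells below), so the same thresholding logic applies to the real reconstruction.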
3 | 19 | { |
4 | 20 | "cell_type": "code", |
5 | 21 | "execution_count": 5, |
|
24 | 40 | "\n", |
25 | 41 | "import time as time\n", |
26 | 42 | "\n", |
27 | | - "from functools import partial\n", |
28 | | - "\n", |
29 | 43 | "# Import torch and select GPU\n", |
30 | 44 | "import torch\n", |
31 | 45 | "\n", |
32 | | - "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"2\"\n", |
| 46 | + "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n", |
33 | 47 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", |
34 | 48 | "if torch.cuda.is_available():\n", |
35 | 49 | " print(torch.cuda.is_available())\n", |
|
559 | 573 | "outputs": [], |
560 | 574 | "source": [] |
561 | 575 | }, |
562 | | - { |
563 | | - "cell_type": "code", |
564 | | - "execution_count": 21, |
565 | | - "metadata": {}, |
566 | | - "outputs": [ |
567 | | - { |
568 | | - "data": { |
569 | | - "text/plain": [ |
570 | | - "65536" |
571 | | - ] |
572 | | - }, |
573 | | - "execution_count": 21, |
574 | | - "metadata": {}, |
575 | | - "output_type": "execute_result" |
576 | | - } |
577 | | - ], |
578 | | - "source": [ |
579 | | - "to_numpy(x_map).size" |
580 | | - ] |
581 | | - }, |
582 | | - { |
583 | | - "cell_type": "code", |
584 | | - "execution_count": 15, |
585 | | - "metadata": {}, |
586 | | - "outputs": [], |
587 | | - "source": [ |
588 | | - "# Function handle for the potential\n", |
589 | | - "def _fun(_x, CRR_model, mu, lmbd):\n", |
590 | | - " return (lmbd / mu) * CRR_model.cost(mu * _x) + likelihood.fun(_x)\n", |
591 | | - "\n", |
592 | | - "\n", |
593 | | - "# Evaluation of the potential\n", |
594 | | - "fun = partial(_fun, CRR_model=CRR_model, mu=mu, lmbd=lmbd)\n", |
595 | | - "# Evaluation of the potential in numpy\n", |
596 | | - "fun_np = lambda _x: fun(qai.utils.to_tensor(_x, dtype=myType)).item()\n", |
597 | | - "\n", |
598 | | - "# Compute HPD region bound\n", |
599 | | - "N = np_x_map.size\n", |
600 | | - "tau_alpha = np.sqrt(16 * np.log(3 / alpha_prob))\n", |
601 | | - "gamma_alpha = fun(x_map).item() + tau_alpha * np.sqrt(N) + N\n", |
602 | | - "\n", |
603 | | - "# Define the wavelet dict\n", |
604 | | - "# Define the l1 norm with dict psi\n", |
605 | | - "Psi = qai.operators.DictionaryWv_torch(wavs_list, levels)\n", |
606 | | - "oper2wavelet = qai.operators.Operation2WaveletCoeffs_torch(Psi=Psi)\n", |
607 | | - "\n", |
608 | | - "# Clone MAP estimation and cast type for wavelet operations\n", |
609 | | - "torch_map = torch.clone(x_map).to(torch.float64)\n", |
610 | | - "torch_x = to_tensor(x_gt).to(torch.float64)" |
611 | | - ] |
612 | | - }, |
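For reviewers: the removed cell above computed the threshold characterising the (1 − α) HPD credible region via Pereyra's concentration bound. In the code's notation (N = `np_x_map.size`, α = `alpha_prob`, f = the potential `fun`):

```latex
\tau_\alpha = \sqrt{16 \log(3/\alpha)}, \qquad
\hat{\gamma}_\alpha = f(x_{\mathrm{MAP}}) + \tau_\alpha \sqrt{N} + N .
```

Any thresholded image whose potential stays below \hat{\gamma}_\alpha remains inside the credible region, which is what the bisection in the next removed cell exploits.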
613 | | - { |
614 | | - "cell_type": "code", |
615 | | - "execution_count": 16, |
616 | | - "metadata": {}, |
617 | | - "outputs": [ |
618 | | - { |
619 | | - "name": "stderr", |
620 | | - "output_type": "stream", |
621 | | - "text": [ |
622 | | - "/disk/xray0/tl3/repos/QuantifAI/quantifai/utils.py:48: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", |
623 | | - " return torch.tensor(z, device=device, dtype=dtype, requires_grad=False).reshape(\n" |
624 | | - ] |
625 | | - } |
626 | | - ], |
627 | | - "source": [ |
628 | | - "def _potential_to_bisect(thresh, fun_np, oper2wavelet, torch_map):\n", |
629 | | - " thresh_img = oper2wavelet.full_op_threshold_img(\n", |
630 | | - " torch_map, thresh, thresh_type=\"hard\"\n", |
631 | | - " )\n", |
632 | | - "\n", |
633 | | - " return gamma_alpha - fun_np(thresh_img)\n", |
634 | | - "\n", |
635 | | - "\n", |
636 | | - "# Evaluation of the potential\n", |
637 | | - "potential_to_bisect = partial(\n", |
638 | | - " _potential_to_bisect, fun_np=fun_np, oper2wavelet=oper2wavelet, torch_map=torch_map\n", |
639 | | - ")\n", |
640 | | - "\n", |
641 | | - "\n", |
642 | | - "selected_thresh, bisec_iters = qai.map_uncertainty.bisection_method(\n", |
643 | | - " potential_to_bisect, start_interval, iters, tol, return_iters=True\n", |
644 | | - ")\n", |
645 | | - "select_thresh_img = oper2wavelet.full_op_threshold_img(torch_map, selected_thresh)" |
646 | | - ] |
647 | | - }, |
648 | | - { |
649 | | - "cell_type": "code", |
650 | | - "execution_count": 17, |
651 | | - "metadata": {}, |
652 | | - "outputs": [ |
653 | | - { |
654 | | - "name": "stdout", |
655 | | - "output_type": "stream", |
656 | | - "text": [ |
657 | | - "SNR (thresh vs MAP) at lvl 0: 20.090000\n", |
658 | | - "SNR (MAP vs GT) at lvl 0: 43.690000\n", |
659 | | - "SNR (thresh vs MAP) at lvl 1: 17.080000\n", |
660 | | - "SNR (MAP vs GT) at lvl 1: 39.120000\n", |
661 | | - "SNR (thresh vs MAP) at lvl 2: 16.550000\n", |
662 | | - "SNR (MAP vs GT) at lvl 2: 32.000000\n", |
663 | | - "SNR (thresh vs MAP) at lvl 3: 22.430000\n", |
664 | | - "SNR (MAP vs GT) at lvl 3: 31.100000\n", |
665 | | - "SNR (thresh vs MAP) at lvl 4: 41.220000\n", |
666 | | - "SNR (MAP vs GT) at lvl 4: 36.690000\n" |
667 | | - ] |
668 | | - } |
669 | | - ], |
670 | | - "source": [ |
671 | | - "modif_img_list = []\n", |
672 | | - "GT_modif_img_list = []\n", |
673 | | - "SNR_at_lvl_list = []\n", |
674 | | - "SNR_at_lvl_map_vs_GT_list = []\n", |
675 | | - "\n", |
676 | | - "for modif_level in range(levels + 1):\n", |
677 | | - " op = lambda x1, x2: x2\n", |
678 | | - "\n", |
679 | | - " modif_img = oper2wavelet.full_op_two_img(\n", |
680 | | - " torch.clone(torch_map), torch.clone(select_thresh_img), op, level=modif_level\n", |
681 | | - " )\n", |
682 | | - " GT_modif_img = oper2wavelet.full_op_two_img(\n", |
683 | | - " torch.clone(torch_x), torch.clone(torch_map), op, level=modif_level\n", |
684 | | - " )\n", |
685 | | - " print(\n", |
686 | | - " \"SNR (thresh vs MAP) at lvl {:d}: {:f}\".format(\n", |
687 | | - " modif_level, qai.utils.eval_snr(to_numpy(torch_map), to_numpy(modif_img))\n", |
688 | | - " )\n", |
689 | | - " )\n", |
690 | | - " print(\n", |
691 | | - " \"SNR (MAP vs GT) at lvl {:d}: {:f}\".format(\n", |
692 | | - " modif_level, qai.utils.eval_snr(to_numpy(torch_x), to_numpy(GT_modif_img))\n", |
693 | | - " )\n", |
694 | | - " )\n", |
695 | | - " modif_img_list.append(to_numpy(modif_img))\n", |
696 | | - " GT_modif_img_list.append(to_numpy(GT_modif_img))\n", |
697 | | - " SNR_at_lvl_list.append(qai.utils.eval_snr(to_numpy(torch_map), to_numpy(modif_img)))\n", |
698 | | - " SNR_at_lvl_map_vs_GT_list.append(\n", |
699 | | - " qai.utils.eval_snr(to_numpy(torch_x), to_numpy(GT_modif_img))\n", |
700 | | - " )" |
701 | | - ] |
702 | | - }, |
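The per-level loop above compares, at each wavelet scale, the thresholded image against the MAP and the MAP-modified image against the ground truth. Assuming `qai.utils.eval_snr` follows the usual definition of image SNR in decibels (an assumption; the exact implementation lives in `quantifai/utils.py`), the printed values correspond to:

```latex
\mathrm{SNR}(x, \hat{x}) = 20 \log_{10} \frac{\lVert x \rVert_2}{\lVert x - \hat{x} \rVert_2} .
```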
703 | | - { |
704 | | - "cell_type": "code", |
705 | | - "execution_count": 4, |
706 | | - "metadata": {}, |
707 | | - "outputs": [ |
708 | | - { |
709 | | - "data": { |
710 | | - "text/plain": [ |
711 | | - "torch.dtype" |
712 | | - ] |
713 | | - }, |
714 | | - "execution_count": 4, |
715 | | - "metadata": {}, |
716 | | - "output_type": "execute_result" |
717 | | - } |
718 | | - ], |
719 | | - "source": [ |
720 | | - "type(torch.float)" |
721 | | - ] |
722 | | - }, |
723 | | - { |
724 | | - "cell_type": "code", |
725 | | - "execution_count": null, |
726 | | - "metadata": {}, |
727 | | - "outputs": [], |
728 | | - "source": [] |
729 | | - }, |
730 | 576 | { |
731 | 577 | "cell_type": "code", |
732 | 578 | "execution_count": null, |
|