Skip to content
Snippets Groups Projects
SDE Performance plots & PFI.ipynb 5.2 MiB
Newer Older
kbb22210's avatar
kbb22210 committed
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "42fcd687-35e8-4d34-9111-03a2ceb48f21",
   "metadata": {},
   "source": [
    "> **Note:** The metric values in this notebook may not exactly match those quoted in the paper, because retraining the model introduces slight variability between runs. The values quoted in the paper are the averages obtained during the dataset-size scaling tests using the full training dataset."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25a8129b-868f-4f46-9abf-c2a4ebba4b73",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d41785e6-65f3-4aea-a94e-6af5bb445f56",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-12-19 11:01:46.598996: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
      "To enable the following instructions: SSE4.1 SSE4.2 AVX AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import sys\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import random\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "from copy import deepcopy\n",
    "import tensorflow as tf\n",
    "# Cap TensorFlow's intra-op thread pool at 16 threads.\n",
    "# NOTE(review): TensorFlow expects this to be set before its runtime is initialised, so keep this call ahead of anything that may create TF ops (possibly the synthdiag imports below) -- confirm.\n",
    "tf.config.threading.set_intra_op_parallelism_threads(16)\n",
    "\n",
    "from synthdiag.Ensemble import NNEnsemble\n",
    "from synthdiag.Autopredictor import Autopredictor\n",
    "from synthdiag.losses import CE_CI, PUV, sharpness"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dac8b1e5-cf4b-4514-97ec-63d146055435",
   "metadata": {},
   "source": [
    "# Import Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2303bbe1-605f-42c1-92c8-af997f4324fd",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TrainingData\n",
      "TestData\n",
      "ProtonEnergies\n",
      "BackSpectWavelengths\n"
     ]
    }
   ],
   "source": [
    "# Load the pre-split train/test data from disk.\n",
    "# NOTE(review): pickle.load can execute arbitrary code -- only load this file from a trusted source.\n",
    "with open(\"TrainTestSplit.pkl\", 'rb') as f:\n",
    "    all_data = pickle.load(f)\n",
    "# List the top-level keys so the available entries are visible in the cell output.\n",
    "for key in all_data:\n",
    "    print(key)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "0e69b04e-6c77-4bb1-a8de-b944d8989158",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pull the four entries of the loaded split into notebook-level names.\n",
    "train, test, protonEnergies, backSpectWavelengths = (\n",
    "    all_data['TrainingData'],\n",
    "    all_data['TestData'],\n",
    "    all_data['ProtonEnergies'],\n",
    "    all_data['BackSpectWavelengths'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f8e01ea1-2293-4feb-b771-37587086ce41",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalise the test spectra by the single global maximum of the whole TPSpect array (np.amax with no axis).\n",
    "spectra_true = test['TPSpect']/np.amax(test['TPSpect'])\n",
    "# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid; use whichever this NumPy provides.\n",
    "_trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz\n",
    "# Integrate each normalised spectrum over protonEnergies along the last axis.\n",
    "flux_true = _trapezoid(spectra_true, x=protonEnergies, axis=-1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "55997164-0c88-4372-a84d-976b72f7c789",
   "metadata": {},
   "source": [
    "# Define inputs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4160ee0a-7e9f-4dfa-9317-c64707b0c7e1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['energy', 'pulseduration', 'prepulsedelay', 'defocus', 'waist', 'encircled_energy', 'BackSpectCM_1', 'BackRfl1wCM_1', 'BackRfl2wCM_1', 'BackSpectCMu_2', 'BackRfl1wCMu_2', 'BackRfl2wCMu_2', 'BackSpectCMu_3', 'BackRfl1wCMu_3', 'BackRfl2wCMu_3']\n"
     ]
    }
   ],
   "source": [
    "diags = ['energy', 'pulseduration', 'prepulsedelay', 'defocus', 'waist', 'encircled_energy']\n",
    "\n",
    "## Add back-reflection moments to diags for orders 1..order.\n",
    "## Order 1 uses the 'CM_1' suffix, higher orders use 'CMu_<k>'; each order adds one entry per back-reflection signal.\n",
    "order = 3\n",
    "for i in range(order):\n",
    "    moment1d_string = f'CM_{i+1}' if i == 0 else f'CMu_{i+1}'\n",
    "    diags.extend(prefix + moment1d_string for prefix in ('BackSpect', 'BackRfl1w', 'BackRfl2w'))\n",
    "print(diags)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a38bf031-2558-41be-aed2-442c6ef9bf99",
   "metadata": {},
   "source": [
    "## Load ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1df8da30-f7c1-474c-9974-0ea4dde5f8c0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
      "Loading beta-VAE weights\n",
      "Loading NN weights\n",
Loading
Loading full blame...