Skip to content

Commit

Permalink
Export 1% relerr
Browse files Browse the repository at this point in the history
  • Loading branch information
larsson4 committed May 24, 2024
1 parent 42ceac6 commit 382f1e5
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions utils/python/linelast_cwtrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# SPDX-License-Identifier: MIT

import numpy as np
from scipy.interpolate import interp1d
import h5py
import os
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -153,6 +154,7 @@ def get_results(samples, prefix):
def solve_time_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"):
plt.plot(samples, res[:,0], label='FOM')
plt.plot(samples, res[:,1], label='ROM')
plt.xscale('log')
plt.xlabel(scale_prefix)
plt.yscale('log')
plt.ylabel("Solve time [s]")
Expand All @@ -164,6 +166,7 @@ def solve_time_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"

def relerr_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"):
plt.plot(samples, res[:,2], label='Relative error')
plt.xscale('log')
plt.xlabel(scale_prefix)
plt.yscale('log')
plt.ylabel("Relative error [-]")
Expand All @@ -174,18 +177,34 @@ def relerr_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"):

def speedup_scaling_plot(samples, res, scale_prefix, plt_name = "scaling.png"):
plt.plot(samples, res[:,3])
plt.xscale('log')
plt.xlabel(scale_prefix)
plt.ylabel("Speedup factor [-]")

plt.tight_layout()
plt.savefig("speedup_" + plt_name, dpi=300)
plt.clf()

def export_opt_val(filename, opt_nb_a, opt_nb_i, opt_speedup):
f = open(filename, "w")
out_txt = "Optimal number of bases (interpolated) is: " + str(round(float(opt_nb_a), 4)) + " bases\nOptimal number of bases (rounded) is: " + str(opt_nb_i) + " bases\nSpeedup at rounded number of bases is: " + str(round(float(opt_speedup), 4)) + " x"
f.write(out_txt)
f.close()

def create_scaling_plot(samples, res, scale_prefix, plt_name = "plot.png"):
plt.rc('axes', labelsize=14)
ferr = interp1d(samples, res[:,2]) # Rel err
ferr_i = interp1d(res[:,2], samples) # Inverse correlation
x_star_a = ferr_i(1e-2) # Analytical x_star
x_star_i = np.ceil(x_star_a) # Next integer

fspeed = interp1d(samples, res[:,3]) # speedup factor
opt_speedup = fspeed(x_star_i)

solve_time_scaling_plot(samples, res, scale_prefix, plt_name)
relerr_scaling_plot(samples, res, scale_prefix, plt_name)
speedup_scaling_plot(samples, res, scale_prefix, plt_name)
export_opt_val("opt_vals.txt", x_star_a, x_star_i, opt_speedup)

def get_nr(txt, split_txt = 'comparison'):
return int(txt.split('.')[0].split(split_txt)[1])
Expand Down

0 comments on commit 382f1e5

Please sign in to comment.