Skip to content

Commit

Permalink
Merge pull request #139 from AngelFP/default_boris
Browse files Browse the repository at this point in the history
Make `boris` the default bunch pusher and enable parallelism
  • Loading branch information
AngelFP authored Oct 17, 2023
2 parents 69cc2dd + 7e2c1fe commit cd16ef5
Show file tree
Hide file tree
Showing 22 changed files with 118 additions and 69 deletions.
4 changes: 2 additions & 2 deletions tests/test_active_plasma_lens.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_active_plasma_lens():
apl.track(bunch)
bunch_params = analyze_bunch(bunch)
gamma_x = bunch_params['gamma_x']
assert approx(gamma_x, rel=1e-10) == 92.14017315271572
assert approx(gamma_x, rel=1e-10) == 92.38646379897074


def test_active_plasma_lens_with_wakefields():
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_active_plasma_lens_with_wakefields():
# Analyze and check results.
bunch_params = analyze_bunch(bunch)
gamma_x = bunch_params['gamma_x']
assert approx(gamma_x, rel=1e-10) == 77.32004939154773
assert approx(gamma_x, rel=1e-10) == 77.31995824746237


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion tests/test_field_element.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_field_element_tracking():
# Check that results have not changed.
bunch_params = analyze_bunch(bunch)
beta_x = bunch_params['beta_x']
assert approx(beta_x, rel=1e-10) == 0.054508554263608434
assert approx(beta_x, rel=1e-10) == 0.05450127309723603


def test_field_element_error():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_fluid_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_fluid_model(plot=False):

# Check final parameters.
ene_sp = params_evolution['rel_ene_spread'][-1]
assert approx(ene_sp, rel=1e-10) == 0.024183646993930535
assert approx(ene_sp, rel=1e-10) == 0.024179998095119972

# Quick plot of results.
if plot:
Expand Down
4 changes: 2 additions & 2 deletions tests/test_multibunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ def test_multibunch_plasma_simulation(plot=False):
# Assert final parameters are correct.
final_energy_driver = driver_params['avg_ene'][-1]
final_energy_witness = witness_params['avg_ene'][-1]
assert approx(final_energy_driver, rel=1e-10) == 1700.33213311266
assert approx(final_energy_witness, rel=1e-10) == 636.3355022503769
assert approx(final_energy_driver, rel=1e-10) == 1700.3927190416732
assert approx(final_energy_witness, rel=1e-10) == 636.330857261392

if plot:
z = driver_params['prop_dist'] * 1e2
Expand Down
4 changes: 2 additions & 2 deletions tests/test_ramps.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_downramp():
downramp.track(bunch)
bunch_params = analyze_bunch(bunch)
beta_x = bunch_params['beta_x']
assert beta_x == 0.009757682933057094
assert beta_x == 0.009750309290619276


def test_upramp():
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_upramp():
downramp.track(bunch)
bunch_params = analyze_bunch(bunch)
beta_x = bunch_params['beta_x']
assert beta_x == 0.0007641796148894145
assert beta_x == 0.0007631600676104024


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions wake_t/beamline_elements/active_plasma_lens.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" This module contains the definition of the ActivePlasmaLens class """

from typing import Optional, Union, Callable
from typing import Optional, Union, Callable, Literal

import numpy as np
import scipy.constants as ct
Expand Down Expand Up @@ -76,7 +76,7 @@ def __init__(
wakefields: bool = False,
density: Optional[Union[float, Callable[[float], float]]] = None,
wakefield_model: Optional[str] = 'quasistatic_2d',
bunch_pusher: Optional[str] = 'rk4',
bunch_pusher: Optional[Literal['boris', 'rk4']] = 'boris',
dt_bunch: Optional[DtBunchType] = 'auto',
n_out: Optional[int] = 1,
name: Optional[str] = 'Active plasma lens',
Expand Down
4 changes: 2 additions & 2 deletions wake_t/beamline_elements/field_element.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, List
from typing import Optional, Union, List, Literal

import scipy.constants as ct

Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(
self,
length: float,
dt_bunch: Union[float, str, List[Union[float, str]]],
bunch_pusher: Optional[str] = 'rk4',
bunch_pusher: Optional[Literal['boris', 'rk4']] = 'boris',
n_out: Optional[int] = 1,
name: Optional[str] = 'field element',
fields: Optional[List[Field]] = [],
Expand Down
4 changes: 2 additions & 2 deletions wake_t/beamline_elements/plasma_ramp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

from typing import Optional, Union, Callable
from typing import Optional, Union, Callable, Literal
from functools import partial

import numpy as np
Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(
plasma_dens_top: Optional[float] = None,
plasma_dens_down: Optional[float] = None,
position_down: Optional[float] = None,
bunch_pusher: Optional[str] = 'rk4',
bunch_pusher: Optional[Literal['boris', 'rk4']] = 'boris',
dt_bunch: Optional[DtBunchType] = 'auto',
n_out: Optional[int] = 1,
name: Optional[str] = 'Plasma ramp',
Expand Down
4 changes: 2 additions & 2 deletions wake_t/beamline_elements/plasma_stage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
""" This module contains the definition of the PlasmaStage class """

from typing import Optional, Union, Callable, List
from typing import Optional, Union, Callable, List, Literal

import numpy as np
import scipy.constants as ct
Expand Down Expand Up @@ -75,7 +75,7 @@ def __init__(
length: float,
density: Union[float, Callable[[float], float]],
wakefield_model: Optional[str] = 'simple_blowout',
bunch_pusher: Optional[str] = 'rk4',
bunch_pusher: Optional[Literal['boris', 'rk4']] = 'boris',
dt_bunch: Optional[DtBunchType] = 'auto',
n_out: Optional[int] = 1,
name: Optional[str] = 'Plasma stage',
Expand Down
17 changes: 9 additions & 8 deletions wake_t/fields/analytical_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from .base import Field
from wake_t.utilities.numba import njit_serial
from wake_t.utilities.numba import njit_parallel


# Define type alias.
Expand Down Expand Up @@ -49,9 +49,10 @@ class AnalyticalField(Field):
Examples
--------
>>> from numba import prange
>>> def linear_ex(x, y, z, t, ex, constants):
... ex_slope = constants[0]
... for i in range(x.shape[0]):
... for i in prange(x.shape[0]):
... ex[i] = ex_slope * x[i]
...
>>> ex = AnalyticField(e_x=linear_ex, constants=[1e6])
Expand All @@ -76,12 +77,12 @@ def no_field(x, y, z, t, fld, k):
"""Default field component."""
pass

self.__e_x = njit_serial(e_x) if e_x is not None else no_field
self.__e_y = njit_serial(e_y) if e_y is not None else no_field
self.__e_z = njit_serial(e_z) if e_z is not None else no_field
self.__b_x = njit_serial(b_x) if b_x is not None else no_field
self.__b_y = njit_serial(b_y) if b_y is not None else no_field
self.__b_z = njit_serial(b_z) if b_z is not None else no_field
self.__e_x = njit_parallel(e_x) if e_x is not None else no_field
self.__e_y = njit_parallel(e_y) if e_y is not None else no_field
self.__e_z = njit_parallel(e_z) if e_z is not None else no_field
self.__b_x = njit_parallel(b_x) if b_x is not None else no_field
self.__b_y = njit_parallel(b_y) if b_y is not None else no_field
self.__b_z = njit_parallel(b_z) if b_z is not None else no_field
self.constants = np.array(constants)

def _pre_gather(self, x, y, z, t):
Expand Down
20 changes: 14 additions & 6 deletions wake_t/fields/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from wake_t.utilities.numba import njit_parallel, prange
from .base import Field


Expand Down Expand Up @@ -48,13 +49,20 @@ def gather_fields(
1D array where the gathered Bz values will be stored
"""
# Initially, set all field values to zero.
ex[:] = 0.
ey[:] = 0.
ez[:] = 0.
bx[:] = 0.
by[:] = 0.
bz[:] = 0.
reset_particle_fields(ex, ey, ez, bx, by, bz)

# Gather contributions from all fields.
for field in fields:
field.gather(x, y, z, t, ex, ey, ez, bx, by, bz)


@njit_parallel
def reset_particle_fields(ex, ey, ez, bx, by, bz):
"""Set bunch field arrays to zero."""
for i in prange(ex.size):
ex[i] = 0.
ey[i] = 0.
ez[i] = 0.
bx[i] = 0.
by[i] = 0.
bz[i] = 0.
6 changes: 3 additions & 3 deletions wake_t/particles/deposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import math
import numpy as np

from wake_t.utilities.numba import njit_serial
from wake_t.utilities.numba import njit_serial, prange


def deposit_3d_distribution(z, x, y, w, z_min, r_min, nz, nr, dz, dr,
Expand Down Expand Up @@ -83,7 +83,7 @@ def deposit_3d_distribution_linear(z, x, y, q, z_min, r_min, nz, nr, dz, dr,
r_max = nr * dr

# Loop over particles.
for i in range(z.shape[0]):
for i in prange(z.shape[0]):
# Get particle components.
x_i = x[i]
y_i = y[i]
Expand Down Expand Up @@ -171,7 +171,7 @@ def deposit_3d_distribution_cubic(z, x, y, q, z_min, r_min, nz, nr, dz, dr,
r_max = nr * dr

# Loop over particles.
for i in range(z.shape[0]):
for i in prange(z.shape[0]):
# Get particle components.
x_i = x[i]
y_i = y[i]
Expand Down
10 changes: 5 additions & 5 deletions wake_t/particles/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
import math
import numpy as np

from wake_t.utilities.numba import njit_serial
from wake_t.utilities.numba import njit_serial, njit_parallel, prange


@njit_serial()
@njit_parallel()
def gather_field_cyl_linear(fld, z_min, z_max, r_min, r_max, dz, dr, x, y, z):
"""
Interpolate a 2D field defined on an r-z grid to the particle positions
Expand Down Expand Up @@ -45,7 +45,7 @@ def gather_field_cyl_linear(fld, z_min, z_max, r_min, r_max, dz, dr, x, y, z):
fld_part = np.zeros(n_part)

# Iterate over all particles.
for i in range(n_part):
for i in prange(n_part):
# Get particle position.
x_i = x[i]
y_i = y[i]
Expand Down Expand Up @@ -87,7 +87,7 @@ def gather_field_cyl_linear(fld, z_min, z_max, r_min, r_max, dz, dr, x, y, z):
return fld_part


@njit_serial()
@njit_parallel()
def gather_main_fields_cyl_linear(
er, ez, bt, z_min, z_max, r_min, r_max, dz, dr, x, y, z,
ex_part, ey_part, ez_part, bx_part, by_part, bz_part):
Expand Down Expand Up @@ -117,7 +117,7 @@ def gather_main_fields_cyl_linear(
n_part = x.shape[0]

# Iterate over all particles.
for i in range(n_part):
for i in prange(n_part):

# Get particle position.
x_i = x[i]
Expand Down
5 changes: 5 additions & 0 deletions wake_t/particles/particle_bunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,11 @@ def evolve(self, fields, t, dt, pusher='rk4'):
apply_rk4_pusher(self, fields, t, dt)
elif pusher == 'boris':
apply_boris_pusher(self, fields, t, dt)
else:
raise ValueError(
f"Bunch pusher '{pusher}' not recognized. "
"Possible values are 'boris' and 'rk4'"
)
self.prop_distance += dt * ct.c

def get_field_arrays(self):
Expand Down
11 changes: 6 additions & 5 deletions wake_t/particles/push/boris_pusher.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
"""
import numpy as np
import scipy.constants as ct
from numba import prange

from wake_t.utilities.numba import njit_serial
from wake_t.utilities.numba import njit_parallel
from wake_t.fields.gather import gather_fields


Expand Down Expand Up @@ -42,9 +43,9 @@ def apply_boris_pusher(bunch, fields, t, dt):
bunch.x, bunch.y, bunch.xi, bunch.px, bunch.py, bunch.pz, dt)


@njit_serial()
@njit_parallel()
def apply_half_position_push(x, y, xi, px, py, pz, dt):
for i in range(x.shape[0]):
for i in prange(x.shape[0]):
# Get particle momentum
px_i = px[i]
py_i = py[i]
Expand All @@ -58,11 +59,11 @@ def apply_half_position_push(x, y, xi, px, py, pz, dt):
xi[i] += 0.5 * (pz_i * c_over_gamma_i - ct.c) * dt


@njit_serial()
@njit_parallel()
def push_momentum(px, py, pz, ex, ey, ez, bx, by, bz, dt, q_over_mc):
k = q_over_mc * dt / 2

for i in range(px.shape[0]):
for i in prange(px.shape[0]):
# Get particle momentum and fields.
px_i = px[i]
py_i = py[i]
Expand Down
27 changes: 14 additions & 13 deletions wake_t/particles/push/runge_kutta_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

import math
import scipy.constants as ct
from numba import prange

from wake_t.utilities.numba import njit_serial
from wake_t.utilities.numba import njit_parallel
from wake_t.fields.gather import gather_fields


Expand Down Expand Up @@ -117,40 +118,40 @@ def apply_rk4_pusher(bunch, fields, t, dt):
apply_push(bunch.pz, dt, dpz)


@njit_serial()
@njit_parallel()
def initialize_coord(x, x_0):
for i in range(x.shape[0]):
for i in prange(x.shape[0]):
x[i] = x_0[i]


@njit_serial()
@njit_parallel()
def update_coord(x, x_0, dt, k_x, fac):
for i in range(x.shape[0]):
for i in prange(x.shape[0]):
x[i] = x_0[i] + dt * k_x[i] * fac


@njit_serial()
@njit_parallel()
def initialize_push(dx, k_x, fac):
for i in range(dx.shape[0]):
for i in prange(dx.shape[0]):
dx[i] = k_x[i] * fac


@njit_serial()
@njit_parallel()
def update_push(dx, k_x, fac):
for i in range(dx.shape[0]):
for i in prange(dx.shape[0]):
dx[i] += k_x[i] * fac


@njit_serial()
@njit_parallel()
def apply_push(x, dt, dx):
for i in range(x.shape[0]):
for i in prange(x.shape[0]):
x[i] += dt * dx[i]


@njit_serial(fastmath=True, error_model='numpy')
@njit_parallel(fastmath=True, error_model='numpy')
def calculate_k(k_x, k_y, k_xi, k_px, k_py, k_pz,
q_over_mc, px, py, pz, ex, ey, ez, bx, by, bz):
for i in range(k_x.shape[0]):
for i in prange(k_x.shape[0]):
px_i = px[i]
py_i = py[i]
pz_i = pz[i]
Expand Down
Loading

0 comments on commit cd16ef5

Please sign in to comment.