Skip to content

Commit

Permalink
Implement phase unwrapping
Browse files Browse the repository at this point in the history
  • Loading branch information
AngelFP committed Nov 15, 2023
1 parent cc774aa commit 70d3ef2
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
3 changes: 2 additions & 1 deletion wake_t/physics_models/laser/envelope_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from wake_t.utilities.numba import njit_serial
from .tdma import TDMA
from .utils import unwrap


@njit_serial(fastmath=True)
Expand Down Expand Up @@ -91,7 +92,7 @@ def evolve_envelope(

# Getting the phase of the envelope on axis.
if use_phase:
phases = np.angle(a[:, 0])
phases = unwrap(np.angle(a[:, 0]))

# Loop over z.
for j in range(nz - 1, -1, -1):
Expand Down
3 changes: 2 additions & 1 deletion wake_t/physics_models/laser/envelope_solver_non_centered.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from wake_t.utilities.numba import njit_serial
from .tdma import TDMA
from .utils import unwrap


@njit_serial(fastmath=True)
Expand Down Expand Up @@ -91,7 +92,7 @@ def evolve_envelope_non_centered(

# Getting the phase of the envelope on axis.
if use_phase:
phases = np.angle(a[:, 0])
phases = unwrap(np.angle(a[:, 0]))

# Loop over z.
for j in range(nz - 1, -1, -1):
Expand Down
59 changes: 59 additions & 0 deletions wake_t/physics_models/laser/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Utilities for the laser envelope solver."""
import numpy as np
from wake_t.utilities.numba import njit_serial


@njit_serial
def unwrap(p, discont=None, axis=-1, period=6.283185307179586):
"""Numba version of numpy.unwrap.
The implementation is taken from
https://github.com/numba/numba/blob/main/numba/np/arraymath.py,
which currently is not yet included in the latest Numba release.
"""
if axis != -1:
msg = 'Value for argument "axis" is not supported'
raise ValueError(msg)
# Flatten to a 2D array, keeping axis -1
p_init = np.asarray(p).astype(np.float64)
init_shape = p_init.shape
last_axis = init_shape[-1]
p_new = p_init.reshape((p_init.size // last_axis, last_axis))
# Manipulate discont and period
if discont is None:
discont = period / 2
interval_high = period / 2
boundary_ambiguous = True
interval_low = -interval_high

slice1 = (slice(1, None, None),)

# Work on each row separately
for i in range(p_init.size // last_axis):
row = p_new[i]
dd = np.diff(row)
ddmod = np.mod(dd - interval_low, period) + interval_low
if boundary_ambiguous:
ddmod = np.where(
(ddmod == interval_low) & (dd > 0),
interval_high,
ddmod
)
ph_correct = ddmod - dd

ph_correct = np.where(
np.array([abs(x) for x in dd]) < discont,
0,
ph_correct
)
ph_ravel = np.where(
np.array([abs(x) for x in dd]) < discont,
0,
ph_correct
)
ph_correct = np.reshape(ph_ravel, ph_correct.shape)
up = np.copy(row)
up[slice1] = row[slice1] + ph_correct.cumsum()
p_new[i] = up

return p_new.reshape(init_shape)

0 comments on commit 70d3ef2

Please sign in to comment.