From 97b43999ded66610b17cbee4393276773a04817f Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 14:52:50 +0200 Subject: [PATCH 01/14] Combined kicker Implementing a corrector with kicks in horizontal and vertical planes --- cheetah/__init__.py | 1 + cheetah/accelerator/__init__.py | 1 + cheetah/accelerator/corrector.py | 175 +++++++++++++++++++++++++++++++ tests/test_correctors.py | 72 +++++++++++++ 4 files changed, 249 insertions(+) create mode 100644 cheetah/accelerator/corrector.py create mode 100644 tests/test_correctors.py diff --git a/cheetah/__init__.py b/cheetah/__init__.py index 37a36995..b13e1e57 100644 --- a/cheetah/__init__.py +++ b/cheetah/__init__.py @@ -3,6 +3,7 @@ BPM, Aperture, Cavity, + Corrector, CustomTransferMap, Dipole, Drift, diff --git a/cheetah/accelerator/__init__.py b/cheetah/accelerator/__init__.py index 3aa65fb3..a27dc438 100644 --- a/cheetah/accelerator/__init__.py +++ b/cheetah/accelerator/__init__.py @@ -1,6 +1,7 @@ from .aperture import Aperture # noqa: F401 from .bpm import BPM # noqa: F401 from .cavity import Cavity # noqa: F401 +from .corrector import Corrector # noqa: F401 from .custom_transfer_map import CustomTransferMap # noqa: F401 from .dipole import Dipole # noqa: F401 from .drift import Drift # noqa: F401 diff --git a/cheetah/accelerator/corrector.py b/cheetah/accelerator/corrector.py new file mode 100644 index 00000000..58455f18 --- /dev/null +++ b/cheetah/accelerator/corrector.py @@ -0,0 +1,175 @@ +from typing import Optional, Union + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.patches import Rectangle +from scipy.constants import physical_constants +from torch import Size, nn + +from cheetah.utils import UniqueNameGenerator + +from .element import Element + +generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") + +electron_mass_eV = torch.tensor( + physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 +) + + +class Corrector(Element): + """ + Corrector magnet in a particle accelerator. + Note: This is modeled as a drift section with + a thin-kick in the horizontal plane followed by + a thin-kick in the vertical plane. + + :param length: Length in meters. + :param horizontal_angle: Particle deflection horizontal_angle in + the horizontal plane in rad. + :param vertical_angle: Particle deflection vertical_angle in + the vertical plane in rad. + :param name: Unique identifier of the element. + """ + + def __init__( + self, + length: Union[torch.Tensor, nn.Parameter], + horizontal_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + vertical_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + name: Optional[str] = None, + device=None, + dtype=torch.float32, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__(name=name) + + self.length = torch.as_tensor(length, **factory_kwargs) + self.horizontal_angle = ( + torch.as_tensor(horizontal_angle, **factory_kwargs) + if horizontal_angle is not None + else torch.zeros_like(self.length) + ) + self.vertical_angle = ( + torch.as_tensor(vertical_angle, **factory_kwargs) + if vertical_angle is not None + else torch.zeros_like(self.length) + ) + + def horizontal_transfer_map(self, energy: torch.Tensor) -> torch.Tensor: + device = self.length.device + dtype = self.length.dtype + + gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) + igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? + igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 + beta = torch.sqrt(1 - igamma2) + + h_tm = torch.eye(7, device=device, dtype=dtype).repeat( + (*self.length.shape, 1, 1) + ) + h_tm[..., 0, 1] = self.length + h_tm[..., 1, 6] = self.horizontal_angle + h_tm[..., 2, 3] = self.length + h_tm[..., 4, 5] = -self.length / beta**2 * igamma2 + + # print(h_tm) + + return h_tm + + def vertical_transfer_map(self, energy: torch.Tensor) -> torch.Tensor: + device = self.length.device + dtype = self.length.dtype + + gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) + igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? + igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 + beta = torch.sqrt(1 - igamma2) + + v_tm = torch.eye(7, device=device, dtype=dtype).repeat( + (*self.length.shape, 1, 1) + ) + v_tm[..., 0, 1] = self.length + v_tm[..., 2, 3] = self.length + v_tm[..., 3, 6] = self.vertical_angle + v_tm[..., 4, 5] = -self.length / beta**2 * igamma2 + + # print(v_tm) + + return v_tm + + def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: + device = self.length.device + dtype = self.length.dtype + + gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) + igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? + igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 + beta = torch.sqrt(1 - igamma2) + + tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) + tm[..., 0, 1] = self.length + tm[..., 2, 3] = self.length + tm[..., 1, 6] = self.horizontal_angle + tm[..., 3, 6] = self.vertical_angle + tm[..., 4, 5] = -self.length / beta**2 * igamma2 + + return tm + + def broadcast(self, shape: Size) -> Element: + return self.__class__( + length=self.length.repeat(shape), + horizontal_angle=self.horizontal_angle, + vertical_angle=self.vertical_angle, + name=self.name, + ) + + @property + def is_skippable(self) -> bool: + return True + + @property + def is_active(self) -> bool: + return any(self.horizontal_angle != 0, self.vertical_angle != 0) + + def split(self, resolution: torch.Tensor) -> list[Element]: + split_elements = [] + remaining = self.length + while remaining > 0: + length = torch.min(resolution, remaining) + element = Corrector( + length, + self.horizontal_angle * length / self.length, + self.vertical_angle * length / self.length, + ) + split_elements.append(element) + remaining -= resolution + return split_elements + + def plot(self, ax: plt.Axes, s: float) -> None: + alpha = 1 if self.is_active else 0.2 + height = (np.sign(self.horizontal_angle[0]) if self.is_active else 1) * ( + np.sign(self.vertical_angle[0]) if self.is_active else 1 + ) + + patch = Rectangle( + (s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2 + ) + ax.add_patch(patch) + + @property + def defining_features(self) -> list[str]: + return super().defining_features + [ + "length", + "horizontal_angle", + "vertical_angle", + ] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(length={repr(self.length)}, " + + f"horizontal_angle={repr(self.horizontal_angle)}, " + + f"vertical_angle={repr(self.vertical_angle)}, " + + f"name={repr(self.name)})" + ) diff --git a/tests/test_correctors.py b/tests/test_correctors.py new file mode 100644 index 00000000..56da7eed --- /dev/null +++ b/tests/test_correctors.py @@ -0,0 +1,72 @@ +import torch + +from cheetah import Corrector, Drift, ParameterBeam, ParticleBeam, Segment + + +def test_corrector_off(): + """ + Test that a corrector with horizontal_angle=0 and vertical_angle=0 behaves + still like a drift. + """ + corrector = Corrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([0.0]), + vertical_angle=torch.tensor([0.0]), + ) + drift = Drift(length=torch.tensor([1.0])) + incoming_beam = ParameterBeam.from_twiss( + energy=torch.tensor([1.8e7]), + beta_x=torch.tensor([5]), + beta_y=torch.tensor([5]), + ) + outbeam_corrector_off = corrector(incoming_beam) + outbeam_drift = drift(incoming_beam) + + corrector.horizontal_angle = torch.tensor( + [2.0], device=corrector.horizontal_angle.device + ) + corrector.vertical_angle = torch.tensor( + [2.0], device=corrector.vertical_angle.device + ) + outbeam_corrector_on = corrector(incoming_beam) + + print(outbeam_corrector_off) + print(outbeam_drift) + print(outbeam_corrector_on) + + assert corrector.name is not None + assert torch.allclose(outbeam_corrector_off.mu_xp, outbeam_drift.mu_xp) + assert torch.allclose(outbeam_corrector_off.mu_yp, outbeam_drift.mu_yp) + assert not torch.allclose(outbeam_corrector_on.mu_xp, outbeam_drift.mu_xp) + assert not torch.allclose(outbeam_corrector_on.mu_yp, outbeam_drift.mu_yp) + + +test_corrector_off() + + +def test_corrector_batched_execution(): + """ + Test that a corrector with batch dimensions behaves as expected. + """ + batch_shape = torch.Size([3]) + incoming = ParticleBeam.from_parameters( + num_particles=torch.tensor(1000000), + energy=torch.tensor([1.8e7]), + ).broadcast(batch_shape) + segment = Segment( + [ + Corrector( + length=torch.tensor([0.04, 0.04, 0.04]), + horizontal_angle=torch.tensor([0.001, 0.003, 0.001]), + vertical_angle=torch.tensor([0.001, 0.002, 0.001]), + ), + Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + ] + ) + outgoing = segment(incoming) + + # Check that dipole with same bend angle produce same output + assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) + + # Check different angles do make a difference + assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) From e20058e6a061ea7132272d85312aa7a5b63e1b3f Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 15:47:15 +0200 Subject: [PATCH 02/14] Adding horizontal and vertical correctors Derived from the Corrector class, tests are included --- cheetah/accelerator/horizontal_corrector.py | 95 +++-------------- cheetah/accelerator/vertical_corrector.py | 96 +++-------------- .../{test_correctors.py => test_corrector.py} | 7 +- tests/test_horizontal_corrector.py | 100 ++++++++++++++++++ tests/test_vertical_corrector.py | 100 ++++++++++++++++++ 5 files changed, 229 insertions(+), 169 deletions(-) rename tests/{test_correctors.py => test_corrector.py} (94%) create mode 100644 tests/test_horizontal_corrector.py create mode 100644 tests/test_vertical_corrector.py diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index f1297101..d8b2567f 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -1,108 +1,39 @@ from typing import Optional, Union -import matplotlib.pyplot as plt -import numpy as np import torch -from matplotlib.patches import Rectangle -from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.utils import UniqueNameGenerator -from .element import Element +from .corrector import Corrector generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) - -class HorizontalCorrector(Element): +class HorizontalCorrector(Corrector): """ Horizontal corrector magnet in a particle accelerator. Note: This is modeled as a drift section with a thin-kick in the horizontal plane. :param length: Length in meters. - :param angle: Particle deflection angle in the horizontal plane in rad. + :param horizontal_angle: Particle deflection angle in the horizontal plane in rad. :param name: Unique identifier of the element. """ def __init__( self, length: Union[torch.Tensor, nn.Parameter], - angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + horizontal_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + # vertical_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, dtype=torch.float32, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__(name=name) - - self.length = torch.as_tensor(length, **factory_kwargs) - self.angle = ( - torch.as_tensor(angle, **factory_kwargs) - if angle is not None - else torch.zeros_like(self.length) - ) - - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - device = self.length.device - dtype = self.length.dtype - - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) - - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) - tm[..., 0, 1] = self.length - tm[..., 1, 6] = self.angle - tm[..., 2, 3] = self.length - tm[..., 4, 5] = -self.length / beta**2 * igamma2 - - return tm - - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), angle=self.angle, name=self.name - ) - - @property - def is_skippable(self) -> bool: - return True - - @property - def is_active(self) -> bool: - return any(self.angle != 0) - - def split(self, resolution: torch.Tensor) -> list[Element]: - split_elements = [] - remaining = self.length - while remaining > 0: - length = torch.min(resolution, remaining) - element = HorizontalCorrector(length, self.angle * length / self.length) - split_elements.append(element) - remaining -= resolution - return split_elements - - def plot(self, ax: plt.Axes, s: float) -> None: - alpha = 1 if self.is_active else 0.2 - height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1) - - patch = Rectangle( - (s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2 - ) - ax.add_patch(patch) - - @property - def defining_features(self) -> list[str]: - return super().defining_features + ["length", "angle"] - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(length={repr(self.length)}, " - + f"angle={repr(self.angle)}, " - + f"name={repr(self.name)})" + ): + super().__init__( + length=length, + horizontal_angle=horizontal_angle, + name=name, + device=device, + dtype=dtype, ) diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 47306e53..8b1593a5 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -1,107 +1,39 @@ from typing import Optional, Union -import matplotlib.pyplot as plt -import numpy as np import torch -from matplotlib.patches import Rectangle -from scipy.constants import physical_constants -from torch import Size, nn +from torch import nn from cheetah.utils import UniqueNameGenerator -from .element import Element +from .corrector import Corrector generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") -electron_mass_eV = torch.tensor( - physical_constants["electron mass energy equivalent in MeV"][0] * 1e6 -) - -class VerticalCorrector(Element): +class VerticalCorrector(Corrector): """ - Verticle corrector magnet in a particle accelerator. + Vertical corrector magnet in a particle accelerator. Note: This is modeled as a drift section with a thin-kick in the vertical plane. :param length: Length in meters. - :param angle: Particle deflection angle in the vertical plane in rad. + :param vertical_angle: Particle deflection angle in the vertical plane in rad. :param name: Unique identifier of the element. """ def __init__( self, length: Union[torch.Tensor, nn.Parameter], - angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + # horizontal_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + vertical_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, dtype=torch.float32, - ) -> None: - factory_kwargs = {"device": device, "dtype": dtype} - super().__init__(name=name) - - self.length = torch.as_tensor(length, **factory_kwargs) - self.angle = ( - torch.as_tensor(angle, **factory_kwargs) - if angle is not None - else torch.zeros_like(self.length) - ) - - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - device = self.length.device - dtype = self.length.dtype - - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) - - tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1)) - tm[..., 0, 1] = self.length - tm[..., 2, 3] = self.length - tm[..., 3, 6] = self.angle - tm[..., 4, 5] = -self.length / beta**2 * igamma2 - return tm - - def broadcast(self, shape: Size) -> Element: - return self.__class__( - length=self.length.repeat(shape), angle=self.angle, name=self.name - ) - - @property - def is_skippable(self) -> bool: - return True - - @property - def is_active(self) -> bool: - return any(self.angle != 0) - - def split(self, resolution: torch.Tensor) -> list[Element]: - split_elements = [] - remaining = self.length - while remaining > 0: - length = torch.min(resolution, remaining) - element = VerticalCorrector(length, self.angle * length / self.length) - split_elements.append(element) - remaining -= resolution - return split_elements - - def plot(self, ax: plt.Axes, s: float) -> None: - alpha = 1 if self.is_active else 0.2 - height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1) - - patch = Rectangle( - (s, 0), self.length[0], height, color="tab:cyan", alpha=alpha, zorder=2 - ) - ax.add_patch(patch) - - @property - def defining_features(self) -> list[str]: - return super().defining_features + ["length", "angle"] - - def __repr__(self) -> str: - return ( - f"{self.__class__.__name__}(length={repr(self.length)}, " - + f"angle={repr(self.angle)}, " - + f"name={repr(self.name)})" + ): + super().__init__( + length=length, + vertical_angle=vertical_angle, + name=name, + device=device, + dtype=dtype, ) diff --git a/tests/test_correctors.py b/tests/test_corrector.py similarity index 94% rename from tests/test_correctors.py rename to tests/test_corrector.py index 56da7eed..a36aca8d 100644 --- a/tests/test_correctors.py +++ b/tests/test_corrector.py @@ -23,10 +23,10 @@ def test_corrector_off(): outbeam_drift = drift(incoming_beam) corrector.horizontal_angle = torch.tensor( - [2.0], device=corrector.horizontal_angle.device + [5.0], device=corrector.horizontal_angle.device ) corrector.vertical_angle = torch.tensor( - [2.0], device=corrector.vertical_angle.device + [7.0], device=corrector.vertical_angle.device ) outbeam_corrector_on = corrector(incoming_beam) @@ -41,9 +41,6 @@ def test_corrector_off(): assert not torch.allclose(outbeam_corrector_on.mu_yp, outbeam_drift.mu_yp) -test_corrector_off() - - def test_corrector_batched_execution(): """ Test that a corrector with batch dimensions behaves as expected. diff --git a/tests/test_horizontal_corrector.py b/tests/test_horizontal_corrector.py new file mode 100644 index 00000000..5b9e1998 --- /dev/null +++ b/tests/test_horizontal_corrector.py @@ -0,0 +1,100 @@ +import torch + +from cheetah import Drift, HorizontalCorrector, ParameterBeam, ParticleBeam, Segment + + +def test_horizontal_corrector_off(): + """ + Test that a corrector with horizontal_angle=0 behaves + still like a drift and that the angle translates properly. + """ + corrector = HorizontalCorrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([0.0]), + ) + drift = Drift(length=torch.tensor([1.0])) + incoming_beam = ParameterBeam.from_twiss( + energy=torch.tensor([1.8e7]), + beta_x=torch.tensor([5]), + beta_y=torch.tensor([5]), + ) + outbeam_corrector_off = corrector(incoming_beam) + outbeam_drift = drift(incoming_beam) + + corrector.horizontal_angle = torch.tensor( + [7.0], device=corrector.horizontal_angle.device + ) + outbeam_corrector_on = corrector(incoming_beam) + + assert corrector.name is not None + assert torch.allclose(outbeam_corrector_off.mu_xp, outbeam_drift.mu_xp) + assert torch.allclose(outbeam_corrector_on.mu_yp, outbeam_drift.mu_yp) + assert torch.allclose(outbeam_corrector_on.mu_xp, corrector.horizontal_angle) + assert not torch.allclose(outbeam_corrector_on.mu_xp, outbeam_drift.mu_xp) + + +def test_horizontal_angle_only(): + """ + Test that the horizontal corrector behaves as expected. + """ + try: + HorizontalCorrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([5.0]), + vertical_angle=torch.tensor([7.0]), + ) + except TypeError: + assert True + + try: + HorizontalCorrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([5.0]), + ) + except TypeError: + assert False + + try: + corrector = HorizontalCorrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([5.0]), + ) + print(corrector.horizontal_angle) + except TypeError: + assert False + + try: + corrector = HorizontalCorrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([5.0]), + ) + print(corrector.vertical_angle) + except TypeError: + assert True + + +def test_corrector_batched_execution(): + """ + Test that a corrector with batch dimensions behaves as expected. + """ + batch_shape = torch.Size([3]) + incoming = ParticleBeam.from_parameters( + num_particles=torch.tensor(1000000), + energy=torch.tensor([1.8e7]), + ).broadcast(batch_shape) + segment = Segment( + [ + HorizontalCorrector( + length=torch.tensor([0.04, 0.04, 0.04]), + horizontal_angle=torch.tensor([0.001, 0.003, 0.001]), + ), + Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + ] + ) + outgoing = segment(incoming) + + # Check that dipole with same bend angle produce same output + assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) + + # Check different angles do make a difference + assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) diff --git a/tests/test_vertical_corrector.py b/tests/test_vertical_corrector.py new file mode 100644 index 00000000..f11c52fe --- /dev/null +++ b/tests/test_vertical_corrector.py @@ -0,0 +1,100 @@ +import torch + +from cheetah import Drift, ParameterBeam, ParticleBeam, Segment, VerticalCorrector + + +def test_vertical_corrector_off(): + """ + Test that a corrector with vertical_angle=0 behaves + still like a drift and that the angle translates properly. + """ + corrector = VerticalCorrector( + length=torch.tensor([0.3]), + vertical_angle=torch.tensor([0.0]), + ) + drift = Drift(length=torch.tensor([1.0])) + incoming_beam = ParameterBeam.from_twiss( + energy=torch.tensor([1.8e7]), + beta_x=torch.tensor([5]), + beta_y=torch.tensor([5]), + ) + outbeam_corrector_off = corrector(incoming_beam) + outbeam_drift = drift(incoming_beam) + + corrector.vertical_angle = torch.tensor( + [7.0], device=corrector.horizontal_angle.device + ) + outbeam_corrector_on = corrector(incoming_beam) + + assert corrector.name is not None + assert torch.allclose(outbeam_corrector_off.mu_yp, outbeam_drift.mu_yp) + assert torch.allclose(outbeam_corrector_on.mu_xp, outbeam_drift.mu_xp) + assert torch.allclose(outbeam_corrector_on.mu_yp, corrector.vertical_angle) + assert not torch.allclose(outbeam_corrector_on.mu_yp, outbeam_drift.mu_yp) + + +def test_vertical_angle_only(): + """ + Test that the vertical corrector behaves as expected. + """ + try: + VerticalCorrector( + length=torch.tensor([0.3]), + vertical_angle=torch.tensor([5.0]), + horizontal_angle=torch.tensor([7.0]), + ) + except TypeError: + assert True + + try: + VerticalCorrector( + length=torch.tensor([0.3]), + vertical_angle=torch.tensor([5.0]), + ) + except TypeError: + assert False + + try: + corrector = VerticalCorrector( + length=torch.tensor([0.3]), + vertical_angle=torch.tensor([5.0]), + ) + print(corrector.vertical_angle) + except TypeError: + assert False + + try: + corrector = VerticalCorrector( + length=torch.tensor([0.3]), + vertical_angle=torch.tensor([5.0]), + ) + print(corrector.horizontal_angle) + except TypeError: + assert True + + +def test_corrector_batched_execution(): + """ + Test that a corrector with batch dimensions behaves as expected. + """ + batch_shape = torch.Size([3]) + incoming = ParticleBeam.from_parameters( + num_particles=torch.tensor(1000000), + energy=torch.tensor([1.8e7]), + ).broadcast(batch_shape) + segment = Segment( + [ + VerticalCorrector( + length=torch.tensor([0.04, 0.04, 0.04]), + vertical_angle=torch.tensor([0.001, 0.003, 0.001]), + ), + Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + ] + ) + outgoing = segment(incoming) + + # Check that dipole with same bend angle produce same output + assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) + + # Check different angles do make a difference + assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) From eb9a09e8144ad405fcdfd1ff023accb564caec50 Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 16:05:09 +0200 Subject: [PATCH 03/14] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 15ba1300..8f13032a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ ### 🚨 Breaking Changes - Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198) (@jank324, @cr-xu) +- New `Corrector` class and modification of `HorizontalCorrector` and `VerticalCorrector` properties (see #207) (@ansantam) ### 🚀 Features From c182a2838e4ad20e5925185cc3dcb4607e0d5960 Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 16:16:00 +0200 Subject: [PATCH 04/14] Fixing the testing :) Random asserts gone (8) --- tests/test_horizontal_corrector.py | 28 +++++++++++----------------- tests/test_vertical_corrector.py | 28 +++++++++++----------------- 2 files changed, 22 insertions(+), 34 deletions(-) diff --git a/tests/test_horizontal_corrector.py b/tests/test_horizontal_corrector.py index 5b9e1998..d4a08219 100644 --- a/tests/test_horizontal_corrector.py +++ b/tests/test_horizontal_corrector.py @@ -44,24 +44,18 @@ def test_horizontal_angle_only(): vertical_angle=torch.tensor([7.0]), ) except TypeError: - assert True + pass - try: - HorizontalCorrector( - length=torch.tensor([0.3]), - horizontal_angle=torch.tensor([5.0]), - ) - except TypeError: - assert False + HorizontalCorrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([5.0]), + ) - try: - corrector = HorizontalCorrector( - length=torch.tensor([0.3]), - horizontal_angle=torch.tensor([5.0]), - ) - print(corrector.horizontal_angle) - except TypeError: - assert False + corrector = HorizontalCorrector( + length=torch.tensor([0.3]), + horizontal_angle=torch.tensor([5.0]), + ) + print(corrector.horizontal_angle) try: corrector = HorizontalCorrector( @@ -70,7 +64,7 @@ def test_horizontal_angle_only(): ) print(corrector.vertical_angle) except TypeError: - assert True + pass def test_corrector_batched_execution(): diff --git a/tests/test_vertical_corrector.py b/tests/test_vertical_corrector.py index f11c52fe..728f898a 100644 --- a/tests/test_vertical_corrector.py +++ b/tests/test_vertical_corrector.py @@ -44,24 +44,18 @@ def test_vertical_angle_only(): horizontal_angle=torch.tensor([7.0]), ) except TypeError: - assert True + pass - try: - VerticalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([5.0]), - ) - except TypeError: - assert False + VerticalCorrector( + length=torch.tensor([0.3]), + vertical_angle=torch.tensor([5.0]), + ) - try: - corrector = VerticalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([5.0]), - ) - print(corrector.vertical_angle) - except TypeError: - assert False + corrector = VerticalCorrector( + length=torch.tensor([0.3]), + vertical_angle=torch.tensor([5.0]), + ) + print(corrector.vertical_angle) try: corrector = VerticalCorrector( @@ -70,7 +64,7 @@ def test_vertical_angle_only(): ) print(corrector.horizontal_angle) except TypeError: - assert True + pass def test_corrector_batched_execution(): From ba273f0fc76bfd959bb6b8baefa0670827378d53 Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 16:16:58 +0200 Subject: [PATCH 05/14] Removing individual h and v tms --- cheetah/accelerator/corrector.py | 42 -------------------------------- 1 file changed, 42 deletions(-) diff --git a/cheetah/accelerator/corrector.py b/cheetah/accelerator/corrector.py index 58455f18..1599d028 100644 --- a/cheetah/accelerator/corrector.py +++ b/cheetah/accelerator/corrector.py @@ -57,48 +57,6 @@ def __init__( else torch.zeros_like(self.length) ) - def horizontal_transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - device = self.length.device - dtype = self.length.dtype - - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) - - h_tm = torch.eye(7, device=device, dtype=dtype).repeat( - (*self.length.shape, 1, 1) - ) - h_tm[..., 0, 1] = self.length - h_tm[..., 1, 6] = self.horizontal_angle - h_tm[..., 2, 3] = self.length - h_tm[..., 4, 5] = -self.length / beta**2 * igamma2 - - # print(h_tm) - - return h_tm - - def vertical_transfer_map(self, energy: torch.Tensor) -> torch.Tensor: - device = self.length.device - dtype = self.length.dtype - - gamma = energy / electron_mass_eV.to(device=device, dtype=dtype) - igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients? - igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2 - beta = torch.sqrt(1 - igamma2) - - v_tm = torch.eye(7, device=device, dtype=dtype).repeat( - (*self.length.shape, 1, 1) - ) - v_tm[..., 0, 1] = self.length - v_tm[..., 2, 3] = self.length - v_tm[..., 3, 6] = self.vertical_angle - v_tm[..., 4, 5] = -self.length / beta**2 * igamma2 - - # print(v_tm) - - return v_tm - def transfer_map(self, energy: torch.Tensor) -> torch.Tensor: device = self.length.device dtype = self.length.dtype From 88c77f752f9aa271861a1cf677e5453dfff2c031 Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 16:35:08 +0200 Subject: [PATCH 06/14] Adding angle property to a void breaking changes --- cheetah/accelerator/horizontal_corrector.py | 13 ++++++-- cheetah/accelerator/vertical_corrector.py | 13 ++++++-- tests/test_horizontal_corrector.py | 34 ++++++------------- tests/test_vertical_corrector.py | 36 ++++++--------------- 4 files changed, 39 insertions(+), 57 deletions(-) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index d8b2567f..9a8c701a 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -24,16 +24,23 @@ class HorizontalCorrector(Corrector): def __init__( self, length: Union[torch.Tensor, nn.Parameter], - horizontal_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, - # vertical_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, dtype=torch.float32, ): super().__init__( length=length, - horizontal_angle=horizontal_angle, + horizontal_angle=angle, name=name, device=device, dtype=dtype, ) + + @property + def angle(self) -> torch.Tensor: + return self.horizontal_angle + + @angle.setter + def angle(self, value: torch.Tensor) -> None: + self.horizontal_angle = value diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 8b1593a5..a7799b18 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -24,16 +24,23 @@ class VerticalCorrector(Corrector): def __init__( self, length: Union[torch.Tensor, nn.Parameter], - # horizontal_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, - vertical_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, + angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, dtype=torch.float32, ): super().__init__( length=length, - vertical_angle=vertical_angle, + vertical_angle=angle, name=name, device=device, dtype=dtype, ) + + @property + def angle(self) -> torch.Tensor: + return self.vertical_angle + + @angle.setter + def angle(self, value: torch.Tensor) -> None: + self.vertical_angle = value diff --git a/tests/test_horizontal_corrector.py b/tests/test_horizontal_corrector.py index d4a08219..28162e14 100644 --- a/tests/test_horizontal_corrector.py +++ b/tests/test_horizontal_corrector.py @@ -10,7 +10,7 @@ def test_horizontal_corrector_off(): """ corrector = HorizontalCorrector( length=torch.tensor([0.3]), - horizontal_angle=torch.tensor([0.0]), + angle=torch.tensor([0.0]), ) drift = Drift(length=torch.tensor([1.0])) incoming_beam = ParameterBeam.from_twiss( @@ -21,48 +21,32 @@ def test_horizontal_corrector_off(): outbeam_corrector_off = corrector(incoming_beam) outbeam_drift = drift(incoming_beam) - corrector.horizontal_angle = torch.tensor( - [7.0], device=corrector.horizontal_angle.device - ) + corrector.angle = torch.tensor([7.0], device=corrector.angle.device) outbeam_corrector_on = corrector(incoming_beam) assert corrector.name is not None assert torch.allclose(outbeam_corrector_off.mu_xp, outbeam_drift.mu_xp) assert torch.allclose(outbeam_corrector_on.mu_yp, outbeam_drift.mu_yp) - assert torch.allclose(outbeam_corrector_on.mu_xp, corrector.horizontal_angle) + assert torch.allclose(outbeam_corrector_on.mu_xp, corrector.angle) assert not torch.allclose(outbeam_corrector_on.mu_xp, outbeam_drift.mu_xp) -def test_horizontal_angle_only(): - """ - Test that the horizontal corrector behaves as expected. - """ +def test_vertical_angle_property(): try: HorizontalCorrector( length=torch.tensor([0.3]), - horizontal_angle=torch.tensor([5.0]), - vertical_angle=torch.tensor([7.0]), + vertical_angle=torch.tensor([0.0]), ) except TypeError: pass - HorizontalCorrector( - length=torch.tensor([0.3]), - horizontal_angle=torch.tensor([5.0]), - ) - - corrector = HorizontalCorrector( - length=torch.tensor([0.3]), - horizontal_angle=torch.tensor([5.0]), - ) - print(corrector.horizontal_angle) +def test_horizontal_angle_property(): try: - corrector = HorizontalCorrector( + HorizontalCorrector( length=torch.tensor([0.3]), - horizontal_angle=torch.tensor([5.0]), + vertical_angle=torch.tensor([0.0]), ) - print(corrector.vertical_angle) except TypeError: pass @@ -80,7 +64,7 @@ def test_corrector_batched_execution(): [ HorizontalCorrector( length=torch.tensor([0.04, 0.04, 0.04]), - horizontal_angle=torch.tensor([0.001, 0.003, 0.001]), + angle=torch.tensor([0.001, 0.003, 0.001]), ), Drift(length=torch.tensor([0.5])).broadcast(batch_shape), ] diff --git a/tests/test_vertical_corrector.py b/tests/test_vertical_corrector.py index 728f898a..dd798934 100644 --- a/tests/test_vertical_corrector.py +++ b/tests/test_vertical_corrector.py @@ -5,12 +5,12 @@ def test_vertical_corrector_off(): """ - Test that a corrector with vertical_angle=0 behaves + Test that a corrector with angle=0 behaves still like a drift and that the angle translates properly. """ corrector = VerticalCorrector( length=torch.tensor([0.3]), - vertical_angle=torch.tensor([0.0]), + angle=torch.tensor([0.0]), ) drift = Drift(length=torch.tensor([1.0])) incoming_beam = ParameterBeam.from_twiss( @@ -21,48 +21,32 @@ def test_vertical_corrector_off(): outbeam_corrector_off = corrector(incoming_beam) outbeam_drift = drift(incoming_beam) - corrector.vertical_angle = torch.tensor( - [7.0], device=corrector.horizontal_angle.device - ) + corrector.angle = torch.tensor([7.0], device=corrector.angle.device) outbeam_corrector_on = corrector(incoming_beam) assert corrector.name is not None assert torch.allclose(outbeam_corrector_off.mu_yp, outbeam_drift.mu_yp) assert torch.allclose(outbeam_corrector_on.mu_xp, outbeam_drift.mu_xp) - assert torch.allclose(outbeam_corrector_on.mu_yp, corrector.vertical_angle) + assert torch.allclose(outbeam_corrector_on.mu_yp, corrector.angle) assert not torch.allclose(outbeam_corrector_on.mu_yp, outbeam_drift.mu_yp) -def test_vertical_angle_only(): - """ - Test that the vertical corrector behaves as expected. - """ +def test_vertical_angle_property(): try: VerticalCorrector( length=torch.tensor([0.3]), - vertical_angle=torch.tensor([5.0]), - horizontal_angle=torch.tensor([7.0]), + vertical_angle=torch.tensor([0.0]), ) except TypeError: pass - VerticalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([5.0]), - ) - - corrector = VerticalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([5.0]), - ) - print(corrector.vertical_angle) +def test_horizontal_angle_property(): try: - corrector = VerticalCorrector( + VerticalCorrector( length=torch.tensor([0.3]), - vertical_angle=torch.tensor([5.0]), + vertical_angle=torch.tensor([0.0]), ) - print(corrector.horizontal_angle) except TypeError: pass @@ -80,7 +64,7 @@ def test_corrector_batched_execution(): [ VerticalCorrector( length=torch.tensor([0.04, 0.04, 0.04]), - vertical_angle=torch.tensor([0.001, 0.003, 0.001]), + angle=torch.tensor([0.001, 0.003, 0.001]), ), Drift(length=torch.tensor([0.5])).broadcast(batch_shape), ] From e1ebb35ae32216d54564827568591b52952db3de Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 16:56:46 +0200 Subject: [PATCH 07/14] Redefining the broadcast function for the other correctors with angle --- cheetah/accelerator/horizontal_corrector.py | 10 +++++++++- cheetah/accelerator/vertical_corrector.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 9a8c701a..6dbaf892 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -1,11 +1,12 @@ from typing import Optional, Union import torch -from torch import nn +from torch import Size, nn from cheetah.utils import UniqueNameGenerator from .corrector import Corrector +from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -37,6 +38,13 @@ def __init__( dtype=dtype, ) + def broadcast(self, shape: Size) -> Element: + return self.__class__( + length=self.length.repeat(shape), + angle=self.angle, + name=self.name, + ) + @property def angle(self) -> torch.Tensor: return self.horizontal_angle diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index a7799b18..96b582dd 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -1,11 +1,12 @@ from typing import Optional, Union import torch -from torch import nn +from torch import Size, nn from cheetah.utils import UniqueNameGenerator from .corrector import Corrector +from .element import Element generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -37,6 +38,13 @@ def __init__( dtype=dtype, ) + def broadcast(self, shape: Size) -> Element: + return self.__class__( + length=self.length.repeat(shape), + angle=self.angle, + name=self.name, + ) + @property def angle(self) -> torch.Tensor: return self.vertical_angle From 68cf1946ba5b752d4f92d94afdd031de5daf07e1 Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 17:14:37 +0200 Subject: [PATCH 08/14] Trying to define angle as correct param To fix test_lattice_json.py --- cheetah/accelerator/horizontal_corrector.py | 7 +++++++ cheetah/accelerator/vertical_corrector.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 6dbaf892..882bae73 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -52,3 +52,10 @@ def angle(self) -> torch.Tensor: @angle.setter def angle(self, value: torch.Tensor) -> None: self.horizontal_angle = value + + @property + def defining_features(self) -> list[str]: + return super().defining_features + [ + "length", + "angle", + ] diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 96b582dd..823941bb 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -52,3 +52,10 @@ def angle(self) -> torch.Tensor: @angle.setter def angle(self, value: torch.Tensor) -> None: self.vertical_angle = value + + @property + def defining_features(self) -> list[str]: + return super().defining_features + [ + "length", + "angle", + ] From 870536dab88be1bdd48466ad44c74b52def75d0f Mon Sep 17 00:00:00 2001 From: Andrea Santamaria Garcia Date: Fri, 28 Jun 2024 17:20:11 +0200 Subject: [PATCH 09/14] Second attempt test lattice json --- cheetah/accelerator/horizontal_corrector.py | 7 +++++++ cheetah/accelerator/vertical_corrector.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 882bae73..24ade0ab 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -59,3 +59,10 @@ def defining_features(self) -> list[str]: "length", "angle", ] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(length={repr(self.length)}, " + + f"angle={repr(self.angle)}, " + + f"name={repr(self.name)})" + ) diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 823941bb..081f9c1b 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -59,3 +59,10 @@ def defining_features(self) -> list[str]: "length", "angle", ] + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(length={repr(self.length)}, " + + f"angle={repr(self.angle)}, " + + f"name={repr(self.name)})" + ) From b312ba5269a6f4f65bfa86405dde9d01e891bb92 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 3 Jul 2024 13:32:16 +0200 Subject: [PATCH 10/14] Add corrector to documentation --- docs/accelerator.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/accelerator.rst b/docs/accelerator.rst index cd5b998a..41437556 100644 --- a/docs/accelerator.rst +++ b/docs/accelerator.rst @@ -15,6 +15,10 @@ Accelerator :members: :undoc-members: +.. automodule:: accelerator.corrector + :members: + :undoc-members: + .. automodule:: accelerator.custom_transfer_map :members: :undoc-members: From dff3c7dd1d786941b625a7a10d02dd78600b09c6 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 3 Jul 2024 13:44:26 +0200 Subject: [PATCH 11/14] Update version Cheetah writes into lattice JSON files --- cheetah/latticejson.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cheetah/latticejson.py b/cheetah/latticejson.py index 0a350360..414d1677 100644 --- a/cheetah/latticejson.py +++ b/cheetah/latticejson.py @@ -92,7 +92,7 @@ def save_cheetah_model( title = segment.name if segment.name is not None else "Unnamed Lattice" metadata = { - "version": "cheetah-0.6", + "version": "cheetah-0.7", "title": title, "info": info, "root": segment.name if segment.name is not None else "cell", From a6d5bcf0cb1c24fe6f4e2113431b6ab452c53bcf Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 3 Jul 2024 13:44:51 +0200 Subject: [PATCH 12/14] First test failing because of inherited defining properties --- cheetah/accelerator/horizontal_corrector.py | 9 ++++++--- cheetah/accelerator/vertical_corrector.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index 24ade0ab..1c8819f0 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -55,10 +55,13 @@ def angle(self, value: torch.Tensor) -> None: @property def defining_features(self) -> list[str]: - return super().defining_features + [ - "length", - "angle", + features_with_both_angles = super().defining_features + cleaned_features = [ + feature + for feature in features_with_both_angles + if feature not in ["horizontal_angle", "vertical_angle"] ] + return cleaned_features + ["length", "angle"] def __repr__(self) -> str: return ( diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index 081f9c1b..8a1700b2 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -55,10 +55,13 @@ def angle(self, value: torch.Tensor) -> None: @property def defining_features(self) -> list[str]: - return super().defining_features + [ - "length", - "angle", + features_with_both_angles = super().defining_features + cleaned_features = [ + feature + for feature in features_with_both_angles + if feature not in ["horizontal_angle", "vertical_angle"] ] + return cleaned_features + ["length", "angle"] def __repr__(self) -> str: return ( From 266b601dc48ecf51c1e7d4449953f303bee04c14 Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Wed, 3 Jul 2024 13:54:05 +0200 Subject: [PATCH 13/14] Fix docstring line length in `Corrector` --- cheetah/accelerator/corrector.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cheetah/accelerator/corrector.py b/cheetah/accelerator/corrector.py index 1599d028..a31eb34e 100644 --- a/cheetah/accelerator/corrector.py +++ b/cheetah/accelerator/corrector.py @@ -20,16 +20,16 @@ class Corrector(Element): """ - Corrector magnet in a particle accelerator. - Note: This is modeled as a drift section with - a thin-kick in the horizontal plane followed by - a thin-kick in the vertical plane. + Combined corrector magnet in a particle accelerator. + + Note: This is modeled as a drift section with a thin-kick in the horizontal plane + followed by a thin-kick in the vertical plane. :param length: Length in meters. - :param horizontal_angle: Particle deflection horizontal_angle in - the horizontal plane in rad. - :param vertical_angle: Particle deflection vertical_angle in - the vertical plane in rad. + :param horizontal_angle: Particle deflection horizontal_angle in the horizontal + plane in rad. + :param vertical_angle: Particle deflection vertical_angle in the vertical plane in + rad. :param name: Unique identifier of the element. """ From 7dde118f6243d9ef7501858f199c2576b3c410bb Mon Sep 17 00:00:00 2001 From: Jan Kaiser Date: Thu, 4 Jul 2024 13:45:22 +0200 Subject: [PATCH 14/14] Some minor review of corrector tests --- tests/test_corrector.py | 31 +++++++++++++----------------- tests/test_horizontal_corrector.py | 25 ++++++++++-------------- tests/test_vertical_corrector.py | 23 ++++++++-------------- 3 files changed, 31 insertions(+), 48 deletions(-) diff --git a/tests/test_corrector.py b/tests/test_corrector.py index a36aca8d..fc2f641f 100644 --- a/tests/test_corrector.py +++ b/tests/test_corrector.py @@ -3,10 +3,11 @@ from cheetah import Corrector, Drift, ParameterBeam, ParticleBeam, Segment -def test_corrector_off(): +def test_corrector_off_on(): """ - Test that a corrector with horizontal_angle=0 and vertical_angle=0 behaves - still like a drift. + Test that a corrector with horizontal_angle=0 and vertical_angle=0 behaves still + like a drift, but when angles are different from 0, it behaves differently from a + drift. """ corrector = Corrector( length=torch.tensor([0.3]), @@ -14,11 +15,11 @@ def test_corrector_off(): vertical_angle=torch.tensor([0.0]), ) drift = Drift(length=torch.tensor([1.0])) + incoming_beam = ParameterBeam.from_twiss( - energy=torch.tensor([1.8e7]), - beta_x=torch.tensor([5]), - beta_y=torch.tensor([5]), + energy=torch.tensor([1.8e7]), beta_x=torch.tensor([5]), beta_y=torch.tensor([5]) ) + outbeam_corrector_off = corrector(incoming_beam) outbeam_drift = drift(incoming_beam) @@ -30,11 +31,6 @@ def test_corrector_off(): ) outbeam_corrector_on = corrector(incoming_beam) - print(outbeam_corrector_off) - print(outbeam_drift) - print(outbeam_corrector_on) - - assert corrector.name is not None assert torch.allclose(outbeam_corrector_off.mu_xp, outbeam_drift.mu_xp) assert torch.allclose(outbeam_corrector_off.mu_yp, outbeam_drift.mu_yp) assert not torch.allclose(outbeam_corrector_on.mu_xp, outbeam_drift.mu_xp) @@ -45,11 +41,10 @@ def test_corrector_batched_execution(): """ Test that a corrector with batch dimensions behaves as expected. """ - batch_shape = torch.Size([3]) + shape = torch.Size([3]) incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(1000000), - energy=torch.tensor([1.8e7]), - ).broadcast(batch_shape) + num_particles=torch.tensor(1_000_000), energy=torch.tensor([1.8e7]) + ).broadcast(shape) segment = Segment( [ Corrector( @@ -57,13 +52,13 @@ def test_corrector_batched_execution(): horizontal_angle=torch.tensor([0.001, 0.003, 0.001]), vertical_angle=torch.tensor([0.001, 0.002, 0.001]), ), - Drift(length=torch.tensor([0.5])).broadcast(batch_shape), + Drift(length=torch.tensor([0.5])).broadcast(shape), ] ) outgoing = segment(incoming) - # Check that dipole with same bend angle produce same output + # Check that a dipole with same bend angle produces the same output assert torch.allclose(outgoing.particles[0], outgoing.particles[2]) - # Check different angles do make a difference + # Check if different angles do make a difference assert not torch.allclose(outgoing.particles[0], outgoing.particles[1]) diff --git a/tests/test_horizontal_corrector.py b/tests/test_horizontal_corrector.py index 28162e14..0c740159 100644 --- a/tests/test_horizontal_corrector.py +++ b/tests/test_horizontal_corrector.py @@ -3,28 +3,25 @@ from cheetah import Drift, HorizontalCorrector, ParameterBeam, ParticleBeam, Segment -def test_horizontal_corrector_off(): +def test_horizontal_corrector_off_on(): """ - Test that a corrector with horizontal_angle=0 behaves - still like a drift and that the angle translates properly. + Test that a corrector with horizontal_angle=0 behaves still like a drift and that + the angle translates properly. """ corrector = HorizontalCorrector( - length=torch.tensor([0.3]), - angle=torch.tensor([0.0]), + length=torch.tensor([0.3]), angle=torch.tensor([0.0]) ) drift = Drift(length=torch.tensor([1.0])) incoming_beam = ParameterBeam.from_twiss( - energy=torch.tensor([1.8e7]), - beta_x=torch.tensor([5]), - beta_y=torch.tensor([5]), + energy=torch.tensor([1.8e7]), beta_x=torch.tensor([5]), beta_y=torch.tensor([5]) ) + outbeam_corrector_off = corrector(incoming_beam) outbeam_drift = drift(incoming_beam) corrector.angle = torch.tensor([7.0], device=corrector.angle.device) outbeam_corrector_on = corrector(incoming_beam) - assert corrector.name is not None assert torch.allclose(outbeam_corrector_off.mu_xp, outbeam_drift.mu_xp) assert torch.allclose(outbeam_corrector_on.mu_yp, outbeam_drift.mu_yp) assert torch.allclose(outbeam_corrector_on.mu_xp, corrector.angle) @@ -34,8 +31,7 @@ def test_horizontal_corrector_off(): def test_vertical_angle_property(): try: HorizontalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([0.0]), + length=torch.tensor([0.3]), vertical_angle=torch.tensor([0.0]) ) except TypeError: pass @@ -44,8 +40,7 @@ def test_vertical_angle_property(): def test_horizontal_angle_property(): try: HorizontalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([0.0]), + length=torch.tensor([0.3]), vertical_angle=torch.tensor([0.0]) ) except TypeError: pass @@ -57,8 +52,7 @@ def test_corrector_batched_execution(): """ batch_shape = torch.Size([3]) incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(1000000), - energy=torch.tensor([1.8e7]), + num_particles=torch.tensor(1_000_000), energy=torch.tensor([1.8e7]) ).broadcast(batch_shape) segment = Segment( [ @@ -69,6 +63,7 @@ def test_corrector_batched_execution(): Drift(length=torch.tensor([0.5])).broadcast(batch_shape), ] ) + outgoing = segment(incoming) # Check that dipole with same bend angle produce same output diff --git a/tests/test_vertical_corrector.py b/tests/test_vertical_corrector.py index dd798934..998cd22b 100644 --- a/tests/test_vertical_corrector.py +++ b/tests/test_vertical_corrector.py @@ -3,28 +3,23 @@ from cheetah import Drift, ParameterBeam, ParticleBeam, Segment, VerticalCorrector -def test_vertical_corrector_off(): +def test_vertical_corrector_off_on(): """ Test that a corrector with angle=0 behaves still like a drift and that the angle translates properly. """ - corrector = VerticalCorrector( - length=torch.tensor([0.3]), - angle=torch.tensor([0.0]), - ) + corrector = VerticalCorrector(length=torch.tensor([0.3]), angle=torch.tensor([0.0])) drift = Drift(length=torch.tensor([1.0])) incoming_beam = ParameterBeam.from_twiss( - energy=torch.tensor([1.8e7]), - beta_x=torch.tensor([5]), - beta_y=torch.tensor([5]), + energy=torch.tensor([1.8e7]), beta_x=torch.tensor([5]), beta_y=torch.tensor([5]) ) + outbeam_corrector_off = corrector(incoming_beam) outbeam_drift = drift(incoming_beam) corrector.angle = torch.tensor([7.0], device=corrector.angle.device) outbeam_corrector_on = corrector(incoming_beam) - assert corrector.name is not None assert torch.allclose(outbeam_corrector_off.mu_yp, outbeam_drift.mu_yp) assert torch.allclose(outbeam_corrector_on.mu_xp, outbeam_drift.mu_xp) assert torch.allclose(outbeam_corrector_on.mu_yp, corrector.angle) @@ -34,8 +29,7 @@ def test_vertical_corrector_off(): def test_vertical_angle_property(): try: VerticalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([0.0]), + length=torch.tensor([0.3]), vertical_angle=torch.tensor([0.0]) ) except TypeError: pass @@ -44,8 +38,7 @@ def test_vertical_angle_property(): def test_horizontal_angle_property(): try: VerticalCorrector( - length=torch.tensor([0.3]), - vertical_angle=torch.tensor([0.0]), + length=torch.tensor([0.3]), vertical_angle=torch.tensor([0.0]) ) except TypeError: pass @@ -57,8 +50,7 @@ def test_corrector_batched_execution(): """ batch_shape = torch.Size([3]) incoming = ParticleBeam.from_parameters( - num_particles=torch.tensor(1000000), - energy=torch.tensor([1.8e7]), + num_particles=torch.tensor(1000000), energy=torch.tensor([1.8e7]) ).broadcast(batch_shape) segment = Segment( [ @@ -69,6 +61,7 @@ def test_corrector_batched_execution(): Drift(length=torch.tensor([0.5])).broadcast(batch_shape), ] ) + outgoing = segment(incoming) # Check that dipole with same bend angle produce same output