Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new combined Corrector element #207

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### 🚨 Breaking Changes

- Cheetah is now vectorised. This means that you can run multiple simulations in parallel by passing a batch of beams and settings, resulting a number of interfaces being changed. For Cheetah developers this means that you now have to account for an arbitrary-dimensional tensor of most of the properties of you element, rather than a single value, vector or whatever else a property was before. (see #116, #157, #170, #172, #173, #198) (@jank324, @cr-xu)
- New `Corrector` class and modification of `HorizontalCorrector` and `VerticalCorrector` properties (see #207) (@ansantam)

### 🚀 Features

Expand Down
1 change: 1 addition & 0 deletions cheetah/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
BPM,
Aperture,
Cavity,
Corrector,
CustomTransferMap,
Dipole,
Drift,
Expand Down
1 change: 1 addition & 0 deletions cheetah/accelerator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .aperture import Aperture # noqa: F401
from .bpm import BPM # noqa: F401
from .cavity import Cavity # noqa: F401
from .corrector import Corrector # noqa: F401
from .custom_transfer_map import CustomTransferMap # noqa: F401
from .dipole import Dipole # noqa: F401
from .drift import Drift # noqa: F401
Expand Down
139 changes: 139 additions & 0 deletions cheetah/accelerator/corrector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import Size, nn

from cheetah.utils import UniqueNameGenerator

from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

electron_mass_eV = torch.tensor(
physical_constants["electron mass energy equivalent in MeV"][0] * 1e6
)


class Corrector(Element):
"""
Combined corrector magnet in a particle accelerator.

Note: This is modeled as a drift section with a thin-kick in the horizontal plane
followed by a thin-kick in the vertical plane.

:param length: Length in meters.
:param horizontal_angle: Particle deflection horizontal_angle in the horizontal
plane in rad.
:param vertical_angle: Particle deflection vertical_angle in the vertical plane in
rad.
:param name: Unique identifier of the element.
"""

def __init__(
self,
length: Union[torch.Tensor, nn.Parameter],
horizontal_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
vertical_angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.register_buffer(
"horizontal_angle",
(
torch.as_tensor(horizontal_angle, **factory_kwargs)
if horizontal_angle is not None
else torch.zeros_like(self.length)
),
)
self.register_buffer(
"vertical_angle",
(
torch.as_tensor(vertical_angle, **factory_kwargs)
if vertical_angle is not None
else torch.zeros_like(self.length)
),
)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
device = self.length.device
dtype = self.length.dtype

gamma = energy / electron_mass_eV.to(device=device, dtype=dtype)
igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients?
igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2
beta = torch.sqrt(1 - igamma2)

tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
tm[..., 0, 1] = self.length
tm[..., 2, 3] = self.length
tm[..., 1, 6] = self.horizontal_angle
tm[..., 3, 6] = self.vertical_angle
tm[..., 4, 5] = -self.length / beta**2 * igamma2

return tm

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape),
horizontal_angle=self.horizontal_angle,
vertical_angle=self.vertical_angle,
name=self.name,
)

@property
def is_skippable(self) -> bool:
return True

@property
def is_active(self) -> bool:
return any(self.horizontal_angle != 0, self.vertical_angle != 0)

def split(self, resolution: torch.Tensor) -> list[Element]:
split_elements = []
remaining = self.length
while remaining > 0:
length = torch.min(resolution, remaining)
element = Corrector(
length,
self.horizontal_angle * length / self.length,
self.vertical_angle * length / self.length,
)
split_elements.append(element)
remaining -= resolution
return split_elements

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = (np.sign(self.horizontal_angle[0]) if self.is_active else 1) * (
np.sign(self.vertical_angle[0]) if self.is_active else 1
)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2
)
ax.add_patch(patch)

@property
def defining_features(self) -> list[str]:
return super().defining_features + [
"length",
"horizontal_angle",
"vertical_angle",
]

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(length={repr(self.length)}, "
+ f"horizontal_angle={repr(self.horizontal_angle)}, "
+ f"vertical_angle={repr(self.vertical_angle)}, "
+ f"name={repr(self.name)})"
)
95 changes: 25 additions & 70 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
from typing import Optional, Union

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from scipy.constants import physical_constants
from torch import Size, nn

from cheetah.utils import UniqueNameGenerator

from .corrector import Corrector
from .element import Element

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

electron_mass_eV = torch.tensor(
physical_constants["electron mass energy equivalent in MeV"][0] * 1e6
)


class HorizontalCorrector(Element):
class HorizontalCorrector(Corrector):
"""
Horizontal corrector magnet in a particle accelerator.
Note: This is modeled as a drift section with
a thin-kick in the horizontal plane.

:param length: Length in meters.
:param angle: Particle deflection angle in the horizontal plane in rad.
:param horizontal_angle: Particle deflection angle in the horizontal plane in rad.
:param name: Unique identifier of the element.
"""

Expand All @@ -36,77 +29,39 @@ def __init__(
name: Optional[str] = None,
device=None,
dtype=torch.float32,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

self.register_buffer("length", torch.as_tensor(length, **factory_kwargs))
self.register_buffer(
"angle",
(
torch.as_tensor(angle, **factory_kwargs)
if angle is not None
else torch.zeros_like(self.length)
),
):
super().__init__(
length=length,
horizontal_angle=angle,
name=name,
device=device,
dtype=dtype,
)

def transfer_map(self, energy: torch.Tensor) -> torch.Tensor:
device = self.length.device
dtype = self.length.dtype

gamma = energy / electron_mass_eV.to(device=device, dtype=dtype)
igamma2 = torch.zeros_like(gamma) # TODO: Effect on gradients?
igamma2[gamma != 0] = 1 / gamma[gamma != 0] ** 2
beta = torch.sqrt(1 - igamma2)

tm = torch.eye(7, device=device, dtype=dtype).repeat((*self.length.shape, 1, 1))
tm[..., 0, 1] = self.length
tm[..., 1, 6] = self.angle
tm[..., 2, 3] = self.length
tm[..., 4, 5] = -self.length / beta**2 * igamma2

return tm

def broadcast(self, shape: Size) -> Element:
return self.__class__(
length=self.length.repeat(shape), angle=self.angle, name=self.name
length=self.length.repeat(shape),
angle=self.angle,
name=self.name,
)

@property
def is_skippable(self) -> bool:
return True
def angle(self) -> torch.Tensor:
return self.horizontal_angle

@property
def is_active(self) -> bool:
return any(self.angle != 0)

def split(self, resolution: torch.Tensor) -> list[Element]:
split_elements = []
remaining = self.length
while remaining > 0:
length = torch.min(resolution, remaining)
element = HorizontalCorrector(
length,
self.angle * length / self.length,
dtype=self.length.dtype,
device=self.length.device,
)
split_elements.append(element)
remaining -= resolution
return split_elements

def plot(self, ax: plt.Axes, s: float) -> None:
alpha = 1 if self.is_active else 0.2
height = 0.8 * (np.sign(self.angle[0]) if self.is_active else 1)

patch = Rectangle(
(s, 0), self.length[0], height, color="tab:blue", alpha=alpha, zorder=2
)
ax.add_patch(patch)
@angle.setter
def angle(self, value: torch.Tensor) -> None:
self.horizontal_angle = value

@property
def defining_features(self) -> list[str]:
return super().defining_features + ["length", "angle"]
features_with_both_angles = super().defining_features
cleaned_features = [
feature
for feature in features_with_both_angles
if feature not in ["horizontal_angle", "vertical_angle"]
]
return cleaned_features + ["length", "angle"]

def __repr__(self) -> str:
return (
Expand Down
Loading
Loading