Improve default dtype selection #254
base: master
Conversation
I've noticed that the docstrings for some of the methods I have edited are not consistent with the method signatures, especially in the case of the …
In principle, the methods for manual initialization of … For many of the methods, I currently only applied the minimal change of removing the explicit default of …
@jank324 It seems to me like the remaining test failures are uncovered, but not caused, by the changes of this PR. The following code also crashes on `master`:

```python
incoming = cheetah.ParticleBeam.from_parameters(
    num_particles=100_000, energy=torch.tensor([154e6, 14e9])
)
element = cheetah.TransverseDeflectingCavity(
    length=torch.tensor(1.0),
    phase=torch.tensor([[0.6], [0.5], [0.4]]),
)
outgoing = element.track(incoming)
```

Maybe this is also relevant for #165?
Created #270 for the issue with the deflecting cavity.
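For reference, the dtype selection this PR moves toward (keep the arguments' `dtype`, otherwise fall back to `torch.get_default_dtype()`) can be sketched as a small helper. Note that `resolve_dtype` is a hypothetical name for illustration, not Cheetah's actual implementation:

```python
import torch


def resolve_dtype(*args, dtype=None):
    """Hypothetical helper illustrating the fallback order described in this PR.

    Priority: explicit dtype argument > dtype of floating-point tensor
    arguments > torch.get_default_dtype(). Not Cheetah's actual code.
    """
    if dtype is not None:
        return dtype
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.is_floating_point():
            return arg.dtype
    return torch.get_default_dtype()


# An explicit dtype always wins:
assert resolve_dtype(torch.tensor(1.0), dtype=torch.float64) == torch.float64
# Otherwise the dtype of tensor arguments is kept:
assert resolve_dtype(torch.tensor(1.0, dtype=torch.float64)) == torch.float64
# With no tensor arguments, fall back to the configured default:
assert resolve_dtype(42) == torch.get_default_dtype()
```

With this order, `ParticleBeam.from_parameters(energy=torch.tensor(1e9, dtype=torch.float64))` would produce a double-precision beam without any global configuration.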
Description

Many methods currently default to creating tensors with `dtype=torch.float32`, regardless of the data passed to them or the default dtype configured in PyTorch. This PR changes those methods to either keep the same `dtype` as their arguments or fall back to `torch.get_default_dtype()` if no reasonable choice is available. Similar changes are implemented for the `device` of said tensors.

Since this PR changes the default arguments of a number of methods, it should be considered a breaking change. The impact is likely negligible, because the new fallback `torch.get_default_dtype()` aligns with the previous default of `torch.float32` if no explicit action is taken by the user.

Motivation and Context
Currently, it is rather cumbersome to track particles with double precision, since many methods default to `torch.float32`. Implementing this change increases compatibility with applications that require tracking using `torch.float64` (or `torch.float16`).

Types of changes
Checklist
- I have checked my code with `flake8` (required).
- I have checked that all `pytest` tests pass (required).
- I have run `pytest` on a machine with a CUDA GPU and made sure all tests pass (required).

Note: We are using a maximum length of 88 characters per line.
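As background for the motivation above: the global default that the new fallback honors is controlled by `torch.set_default_dtype`, so users can opt into double precision without touching individual calls. This is plain PyTorch behavior, no Cheetah required:

```python
import torch

# Methods that receive no tensor arguments would now create tensors with
# torch.get_default_dtype(), so switching the global default is enough to
# get double precision throughout:
torch.set_default_dtype(torch.float64)
x = torch.tensor(1.0)  # created as float64 instead of float32
assert x.dtype == torch.float64
assert torch.get_default_dtype() == torch.float64

# Restore the usual default so other code is unaffected.
torch.set_default_dtype(torch.float32)
```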