
Improve default dtype selection #254

Open · wants to merge 17 commits into base: master

Conversation

@Hespe Hespe commented Sep 17, 2024

Description

Many methods currently default to creating tensors with dtype=torch.float32, regardless of the data passed to them or the default dtype configured in PyTorch. This PR changes those methods to either keep the dtype of their arguments or fall back to torch.get_default_dtype() if no reasonable choice is available.
Similar changes are implemented for the device of said tensors.
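For illustration, a minimal sketch of that fallback pattern (make_tensor is a hypothetical helper for this comment, not part of the actual Cheetah API):

import torch

def make_tensor(value, dtype=None, device=None):
    """Hypothetical helper illustrating the fallback described above."""
    if isinstance(value, torch.Tensor):
        # Keep the dtype and device of a tensor argument unless overridden.
        dtype = dtype if dtype is not None else value.dtype
        device = device if device is not None else value.device
    elif dtype is None:
        # No tensor to inherit from: fall back to the configured PyTorch default.
        dtype = torch.get_default_dtype()
    return torch.as_tensor(value, dtype=dtype, device=device)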

Since this PR changes the default arguments of a number of methods, it should be considered a breaking change. The impact is likely negligible, because the new fallback torch.get_default_dtype() matches the previous default of torch.float32 unless the user explicitly changes the default dtype.

Motivation and Context

Currently, it is rather cumbersome to track particles with double precision, since many methods default to torch.float32. This change improves compatibility with applications that require tracking in torch.float64 (or torch.float16).
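As a sketch of how double-precision tracking would look with this change in place (the use of Drift here is illustrative; exact signatures may differ):

import torch
import cheetah

# With the new fallback, switching the whole simulation to double precision
# only requires changing the PyTorch default dtype.
torch.set_default_dtype(torch.float64)

incoming = cheetah.ParticleBeam.from_parameters(num_particles=10_000)
segment = cheetah.Drift(length=torch.tensor(1.0))  # created as float64
outgoing = segment.track(incoming)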

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)

Checklist

  • I have updated the changelog accordingly (required).
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.
  • I have reformatted the code and checked that formatting passes (required).
  • I have fixed all issues found by flake8 (required).
  • I have ensured that all pytest tests pass (required).
  • I have run pytest on a machine with a CUDA GPU and made sure all tests pass (required).
  • I have checked that the documentation builds (required).

Note: We are using a maximum length of 88 characters per line.

@Hespe Hespe added the "enhancement" (New feature or request) label Sep 17, 2024
@Hespe Hespe linked an issue Sep 17, 2024 that may be closed by this pull request
Hespe commented Sep 17, 2024

I've noticed that the docstrings of some of the methods I have edited are not consistent with the method signatures, especially in the case of ParticleBeam. This is in addition to the issues already noted in #238.

@Hespe Hespe self-assigned this Sep 20, 2024
Hespe commented Sep 23, 2024

In principle, the methods for manual initialization of Beam and Element objects are done. We probably also want to change the methods for importing lattices and beams defined in other tools, but I would argue it's probably best to leave that for a separate PR, since this one already touches enough files as it is.

For many of the methods, I have currently only applied the minimal change of removing the explicit default of torch.float32, leaving the dtype selection to the built-in PyTorch functions. There might, however, be an issue with that approach: by supplying dtype=None, the default behaviour of the built-in functions is to keep the dtype of their input argument. Therefore, an Element might end up with conflicting dtypes for its parameters if different types are passed to the constructor.
We could either say that it is the user's responsibility to ensure a consistent dtype is used, or we have to infer the dtype ourselves from all non-None arguments, similar to the from_twiss() function of Beam (see the sketch below).
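To make the second option concrete, a hedged sketch of such an inference (infer_dtype is a hypothetical helper, only loosely modelled on what from_twiss does):

import torch

def infer_dtype(*args, default=None):
    # Collect the dtypes of all tensor arguments that were actually provided
    # (i.e. are not None).
    dtypes = {arg.dtype for arg in args if isinstance(arg, torch.Tensor)}
    if len(dtypes) > 1:
        raise ValueError(f"Arguments have conflicting dtypes: {dtypes}")
    if dtypes:
        return dtypes.pop()
    # No tensor arguments were given: fall back to the PyTorch default.
    return default if default is not None else torch.get_default_dtype()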

@Hespe Hespe marked this pull request as ready for review September 24, 2024 12:58
Hespe commented Oct 7, 2024

@jank324 It seems to me like the remaining test failures are uncovered, but not caused, by the changes in this PR. The following code also crashes on master:

import torch
import cheetah

# Minimal reproduction: this also crashes on master, so the failure is not
# introduced by this PR.
incoming = cheetah.ParticleBeam.from_parameters(
    num_particles=100_000, energy=torch.tensor([154e6, 14e9])
)
element = cheetah.TransverseDeflectingCavity(
    length=torch.tensor(1.0),
    phase=torch.tensor([[0.6], [0.5], [0.4]]),
)
outgoing = element.track(incoming)

Maybe this is also relevant for #165?

Hespe commented Oct 7, 2024

Created #270 for the issue with the deflecting cavity.

jank324 commented Oct 16, 2024

> Created #270 for the issue with the deflecting cavity.

I just merged the updated master, including the fix for #270, into this branch.

Successfully merging this pull request may close these issues.

Fall back to PyTorch default dtype if no explicit type is provided