Commit

refactor: merge WeightedDataGenerator into DataGenerator (#458)
* chore: upgrade Jupyter notebook kernels
* docs: add link to TR-018
* feat: embed weights as key to `DataSample`
* feat: implement phase space weights in `UnbinnedNLL`
redeboer authored Aug 9, 2022
1 parent 811c17b commit 642e774
Showing 15 changed files with 128 additions and 52 deletions.
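The commit message's "embed weights as key to `DataSample`" can be illustrated with a minimal sketch. It assumes plain NumPy arrays in a `dict` stand in for a `DataSample`; the helper name `hit_and_miss` and the toy intensity are hypothetical and not part of the package. It mirrors the `domain.get("weights", 1)` fallback that appears in the `IntensityDistributionGenerator` diff below.

```python
import numpy as np


def hit_and_miss(sample, intensity, rng):
    """Keep events for which weight * intensity exceeds a random threshold.

    `hit_and_miss` is a hypothetical helper, not a tensorwaves function.
    """
    # A sample without a "weights" key is treated as unweighted.
    weights = sample.get("weights", 1)
    intensities = intensity(sample)
    thresholds = rng.uniform(0, np.max(intensities), size=len(intensities))
    mask = weights * intensities > thresholds
    # The weights travel with the sample, so they are filtered like any other key.
    return {key: values[mask] for key, values in sample.items()}


rng = np.random.default_rng(seed=0)
sample = {
    "weights": rng.uniform(0, 1, size=1000),
    "x": rng.uniform(-1, 1, size=1000),
}
selected = hit_and_miss(sample, lambda s: 1 - s["x"] ** 2, rng)
```

Because the weights are just another entry in the mapping, a single generator interface can serve both weighted and unweighted samples, which appears to be the motivation for dropping `WeightedDataGenerator`.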
18 changes: 16 additions & 2 deletions docs/amplitude-analysis.ipynb
@@ -299,6 +299,12 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
+    "::::{margin}\n",
+    ":::{tip}\n",
+    "{doc}`TR-018<compwa-org:report/018>` explains some of the mechanisms behind the phase space generator as well as how to do {ref}`importance sampling<compwa-org:report/018:Intensity distribution>`.\n",
+    ":::\n",
+    "::::\n",
+    "\n",
     "In this section, we use the {class}`~ampform.helicity.HelicityModel` that we created with {mod}`ampform` in {ref}`the previous step <compwa-step-1>` to generate a data sample via hit & miss Monte Carlo. We do this with the {mod}`.data` module.\n",
     "\n",
     "First, we {func}`~pickle.load` the {class}`~ampform.helicity.HelicityModel` that was created in the previous step. This does not have to be done if the model has been generated in the same script or notebook, but can be useful if the model was generated elsewhere."
@@ -353,7 +359,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "The {class}`~qrules.transition.ReactionInfo` class defines the constraints of the phase space. As such, we have enough information to generate a **phase-space sample** for this particle reaction. We do this with a {class}`.TFPhaseSpaceGenerator` class, which is an implementation of the {class}`.DataGenerator` for a {obj}`.DataSample` of **four-momenta** arrays (using {obj}`tensorflow <tf.Tensor>` and the [`phasespace`](https://phasespace.readthedocs.io) package as a back-end). We also need to construct a {class}`.RealNumberGenerator` that can generate random numbers. {class}`.TFUniformRealNumberGenerator` is the natural choice here.\n",
+    "The {class}`~qrules.transition.ReactionInfo` class defines the constraints of the phase space. As such, we have enough information to generate a **phase-space sample** for this particle reaction. We do this with a {class}`.TFPhaseSpaceGenerator` class, which is a {class}`.DataGenerator` for a {obj}`.DataSample` of **four-momenta** arrays (using {obj}`tensorflow <tf.Tensor>` and the [`phasespace`](https://phasespace.readthedocs.io) package as a back-end). We also need to construct a {class}`.RealNumberGenerator` that can generate random numbers. {class}`.TFUniformRealNumberGenerator` is the natural choice here.\n",
     "\n",
     "As opposed to the main {ref}`amplitude-analysis:Step 2: Generate data` of the main usage example page, we will generate a **deterministic** data sample. This can be done by feeding a {class}`.RealNumberGenerator` with a specific {attr}`~.RealNumberGenerator.seed` and giving that generator to the {meth}`.TFPhaseSpaceGenerator.generate` method:"
    ]
@@ -1935,8 +1941,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
8 changes: 8 additions & 0 deletions docs/amplitude-analysis/analytic-continuation.ipynb
@@ -290,7 +290,15 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
    "version": "3.8.13"
   }
  },
10 changes: 9 additions & 1 deletion docs/usage.ipynb
@@ -857,8 +857,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
10 changes: 9 additions & 1 deletion docs/usage/basics.ipynb
@@ -1291,8 +1291,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
10 changes: 9 additions & 1 deletion docs/usage/binned-fit.ipynb
@@ -328,8 +328,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
10 changes: 9 additions & 1 deletion docs/usage/caching.ipynb
@@ -540,8 +540,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
10 changes: 9 additions & 1 deletion docs/usage/chi-squared.ipynb
@@ -274,8 +274,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
10 changes: 9 additions & 1 deletion docs/usage/faster-lambdify.ipynb
@@ -413,8 +413,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
10 changes: 9 additions & 1 deletion docs/usage/unbinned-fit.ipynb
@@ -351,8 +351,16 @@
    "name": "python3"
   },
   "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
    "name": "python",
-   "version": "3.8.12"
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.13"
   }
  },
  "nbformat": 4,
19 changes: 7 additions & 12 deletions src/tensorwaves/data/__init__.py
@@ -13,7 +13,6 @@
     DataTransformer,
     Function,
     RealNumberGenerator,
-    WeightedDataGenerator,
 )
 
 from ._data_sample import (
@@ -71,7 +70,7 @@ class IntensityDistributionGenerator(DataGenerator):
 
     def __init__(
         self,
-        domain_generator: DataGenerator | WeightedDataGenerator,
+        domain_generator: DataGenerator,
         function: Function,
         domain_transformer: DataTransformer | None = None,
         bunch_size: int = 50_000,
@@ -115,18 +114,14 @@ def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
         return select_events(returned_data, selector=slice(None, size))
 
     def _generate_bunch(self, rng: RealNumberGenerator) -> tuple[DataSample, float]:
-        domain_generator = self.__domain_generator
-        if isinstance(domain_generator, WeightedDataGenerator):
-            domain, weights = domain_generator.generate(self.__bunch_size, rng)
-        else:
-            domain = _generate_without_progress_bar(
-                domain_generator, self.__bunch_size, rng
-            )
-            weights = 1  # type: ignore[assignment]
+        domain = _generate_without_progress_bar(
+            self.__domain_generator, self.__bunch_size, rng
+        )
         transformed_domain = self.__domain_transformer(domain)
         computed_intensities = self.__function(transformed_domain)
         max_intensity: float = np.max(computed_intensities)
         random_intensities = rng(size=self.__bunch_size, max_value=max_intensity)
+        weights = domain.get("weights", 1)
         hit_and_miss_sample = select_events(
             domain,
             selector=weights * computed_intensities > random_intensities,
@@ -139,9 +134,9 @@ def _generate_without_progress_bar(
 ) -> DataSample:
     # https://github.com/ComPWA/tensorwaves/issues/395
     show_progress = getattr(domain_generator, "show_progress", None)
-    if show_progress:
+    if show_progress is not None:
         domain_generator.show_progress = False  # type: ignore[attr-defined]
     domain = domain_generator.generate(bunch_size, rng)
-    if show_progress:
+    if show_progress is not None:
         domain_generator.show_progress = show_progress  # type: ignore[attr-defined]
     return domain
39 changes: 22 additions & 17 deletions src/tensorwaves/data/phasespace.py
@@ -1,20 +1,14 @@
 # pylint: disable=import-outside-toplevel
-"""Implementations of `.DataGenerator` and `.WeightedDataGenerator`."""
+"""Implementations of a `.DataGenerator` for four-momentum samples."""
 from __future__ import annotations
 
 import logging
 from typing import Mapping
 
-import numpy as np
 from tqdm.auto import tqdm
 
 from tensorwaves.function._backend import raise_missing_module_error
-from tensorwaves.interface import (
-    DataGenerator,
-    DataSample,
-    RealNumberGenerator,
-    WeightedDataGenerator,
-)
+from tensorwaves.interface import DataGenerator, DataSample, RealNumberGenerator
 
 from ._data_sample import (
     finalize_progress_bar,
@@ -64,20 +58,30 @@ def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
         )
         momentum_pool: DataSample = {}
         while get_number_of_events(momentum_pool) < size:
-            phsp_momenta, weights = self.__phsp_generator.generate(
-                self.__bunch_size, rng
-            )
+            phsp_momenta = self.__phsp_generator.generate(self.__bunch_size, rng)
+            weights = phsp_momenta.get("weights")
+            if weights is None:
+                raise ValueError(
+                    "DataSample returned by"
+                    f" {type(self.__phsp_generator).__name__} doesn't contain"
+                    ' "weights"'
+                )
             hit_and_miss_randoms = rng(self.__bunch_size)
             bunch = select_events(phsp_momenta, selector=weights > hit_and_miss_randoms)
             momentum_pool = merge_events(momentum_pool, bunch)
             progress_bar.update(n=get_number_of_events(bunch))
         finalize_progress_bar(progress_bar)
-        return select_events(momentum_pool, selector=slice(None, size))
+        phsp = select_events(momentum_pool, selector=slice(None, size))
+        del phsp["weights"]
+        return phsp
 
 
-class TFWeightedPhaseSpaceGenerator(WeightedDataGenerator):
+class TFWeightedPhaseSpaceGenerator(DataGenerator):
     """Implements a phase space generator **with weights** using tensorflow.
 
+    The weights are provided in the returned `.DataSample` under the key
+    :code:`"weights"`.
+
     Args:
         initial_state_mass: Mass of the decaying state.
         final_state_masses: A mapping of final state IDs to the corresponding masses.
@@ -102,9 +106,7 @@ def __init__(
             names=list(map(str, sorted_ids)),
         )
 
-    def generate(
-        self, size: int, rng: RealNumberGenerator
-    ) -> tuple[DataSample, np.ndarray]:
+    def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
         r"""Generate a `.DataSample` of phase space four-momenta with weights.
 
         Returns:
@@ -122,4 +124,7 @@ def generate(
             f"p{label}": momenta.numpy()[:, [3, 0, 1, 2]]
             for label, momenta in particles.items()
         }
-        return phsp_momenta, weights.numpy()
+        return {
+            "weights": weights.numpy(),
+            **phsp_momenta,
+        }
5 changes: 4 additions & 1 deletion src/tensorwaves/estimator.py
@@ -193,7 +193,8 @@ def __init__( # pylint: disable=too-many-arguments
         backend: str = "numpy",
     ) -> None:
         self.__data = {k: np.array(v) for k, v in data.items()}
-        self.__phsp = {k: np.array(v) for k, v in phsp.items()}
+        self.__phsp = {k: np.array(v) for k, v in phsp.items() if k != "weights"}
+        self.__phsp_weights = phsp.get("weights")
         self.__function = function
         self.__gradient = gradient_creator(self.__call__, backend)
 
@@ -207,6 +208,8 @@ def __call__(self, parameters: Mapping[str, ParameterValue]) -> float:
         self.__function.update_parameters(parameters)
         bare_intensities = self.__function(self.__data)
         phsp_intensities = self.__function(self.__phsp)
+        if self.__phsp_weights is not None:
+            phsp_intensities *= self.__phsp_weights
         normalization_factor = 1.0 / (
             self.__phsp_volume * self.__mean_function(phsp_intensities)
         )
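To make the effect of the estimator change concrete, here is a rough NumPy-only sketch of a weighted normalization. The function name `unbinned_nll`, its signature, and the toy samples are hypothetical. The final log-likelihood line is an assumption based on the usual definition of an unbinned negative log-likelihood, since that part of `UnbinnedNLL.__call__` is not shown in this diff.

```python
import numpy as np


def unbinned_nll(intensity, data, phsp, phsp_volume=1.0):
    """Hypothetical stand-in for the weighted normalization in the diff above."""
    phsp = dict(phsp)  # copy, so the caller's sample keeps its "weights" key
    weights = phsp.pop("weights", None)
    phsp_intensities = intensity(phsp)
    if weights is not None:
        # As in the diff: scale each phase space intensity by its weight.
        phsp_intensities = phsp_intensities * weights
    normalization = 1.0 / (phsp_volume * np.mean(phsp_intensities))
    # Assumption: standard unbinned NLL over the normalized intensities.
    return -np.sum(np.log(normalization * intensity(data)))


def flat_intensity(sample):
    return np.ones_like(sample["x"])


data = {"x": np.array([0.1, 0.2, 0.3])}
phsp = {"x": np.linspace(-1, 1, 11)}
nll_unweighted = unbinned_nll(flat_intensity, data, phsp)
nll_weighted = unbinned_nll(
    flat_intensity, data, {**phsp, "weights": np.full(11, 2.0)}
)
```

Note that in this sketch, as in the diff, the weighted intensities are averaged with a plain mean, so the absolute scale of the weights shifts the resulting value by a constant.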
12 changes: 1 addition & 11 deletions src/tensorwaves/interface.py
@@ -231,17 +231,7 @@ class DataGenerator(ABC):
 
     @abstractmethod
     def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
-        ...
-
-
-class WeightedDataGenerator(ABC):
-    """Abstract class for generating a `.DataSample` with weights."""
-
-    @abstractmethod
-    def generate(
-        self, size: int, rng: RealNumberGenerator
-    ) -> tuple[DataSample, np.ndarray]:
-        r"""Generate `.DataSample` with weights.
+        r"""Generate a `.DataSample` with :code:`size` events.
 
         Returns:
             A `tuple` of a `.DataSample` with an array of weights.
4 changes: 3 additions & 1 deletion tests/data/test_data.py
@@ -1,4 +1,6 @@
 # pylint: disable=import-outside-toplevel
+from __future__ import annotations
+
 from typing import TYPE_CHECKING
 
 import numpy as np
@@ -96,7 +98,7 @@ def test_generate_four_momenta_on_flat_distribution(self):
             assert pytest.approx(phsp[i]) == data[i]
 
 
-def test_generate_without_progress_bar(capsys: "CaptureFixture"):
+def test_generate_without_progress_bar(capsys: CaptureFixture):
     class SilentGenerator(DataGenerator):
         def generate(self, size: int, rng: RealNumberGenerator) -> DataSample:
             return {"x": 1}  # type: ignore[dict-item]
5 changes: 4 additions & 1 deletion tests/data/test_phasespace.py
@@ -137,7 +137,10 @@ def test_generate_deterministic(self, pdg: "ParticleCollection"):
                 i: pdg[name].mass for i, name in enumerate(final_state_names)
             },
         )
-        phsp_momenta, weights = phsp_generator.generate(sample_size, rng)
+        phsp_momenta = phsp_generator.generate(sample_size, rng)
+        assert list(phsp_momenta) == ["weights", "p0", "p1", "p2"]
+        weights = phsp_momenta.get("weights", [])
+        del phsp_momenta["weights"]
         print("Expected values, get by running pytest with the -s flag")
         pprint(
             {
