Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an AnalysisCollection class #4017

Open
wants to merge 9 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ Fixes
* Fix groups.py doctests using sphinx directives (Issue #3925, PR #4374)

Enhancements
* Add an `AnalaysisCollection` class to perform multiple analysis on the same
trajectory (#3569, PR #4017).
* Added a tqdm progress bar for `MDAnalysis.analysis.pca.PCA.transform()`
(PR #4531)
* Improved performance of PDBWriter (Issue #2785, PR #4472)
Expand Down
235 changes: 202 additions & 33 deletions package/MDAnalysis/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@
plt.xlabel("time (ps)")
plt.ylabel("RMSD (Å)")

If you want to run two or more different analyses on the same trajectory you
can also efficently combine them using the
:class:`MDAnalysis.analysis.base.AnalysisCollection` class.


Writing new analysis tools
--------------------------
Expand Down Expand Up @@ -121,6 +125,7 @@
tools if only the single-frame analysis function needs to be written.

"""
from typing import Optional, Tuple
from collections import UserDict
import inspect
import logging
Expand Down Expand Up @@ -220,7 +225,71 @@ def __setstate__(self, state):
self.data = state


class AnalysisBase(object):
def _run(
analysis_instances: Tuple["AnalysisBase", ...],
start: Optional[int] = None,
stop: Optional[int] = None,
step: Optional[int] = None,
frames: Optional[int] = None,
verbose: Optional[bool] = None,
progressbar_kwargs: Optional[dict] = None,
) -> None:
"""Implementation of common run method."""

# if verbose unchanged, use default of first analysis_instance
verbose = getattr(analysis_instances[0], '_verbose',
False) if verbose is None else verbose

logger.info("Choosing frames to analyze")
for analysis_object in analysis_instances:
analysis_object._setup_frames(
analysis_object._trajectory,
start=start,
stop=stop,
step=step,
frames=frames,
)

logger.info("Starting preparation")
for analysis_object in analysis_instances:
analysis_object._prepare()

if progressbar_kwargs is None:
progressbar_kwargs = {}

logger.info(
"Starting analysis loop over "
f"{analysis_instances[0].n_frames} trajectory frames"
)
for i, ts in enumerate(
ProgressBar(
analysis_instances[0]._sliced_trajectory,
verbose=verbose,
**progressbar_kwargs,
)
):
ts_original = ts.copy()

for analysis_object in analysis_instances:
# Set attributes before calling `_single_frame()`. Setting
# these attributes explicitly is mandatory so that each
# instance can access the information of the current timestep.
analysis_object._frame_index = i
analysis_object._ts = ts
analysis_object.frames[i] = ts.frame
analysis_object.times[i] = ts.time

# Call the actual analysis of each instance.
analysis_object._single_frame()

ts = ts_original

logger.info("Finishing up")
for analysis_object in analysis_instances:
analysis_object._conclude()


class AnalysisBase:
r"""Base class for defining multi-frame analysis

The class is designed as a template for creating multi-frame analyses.
Expand Down Expand Up @@ -333,8 +402,10 @@ def _setup_frames(self, trajectory, start=None, stop=None, step=None,
step : int, optional
number of frames to skip between each analysed frame
frames : array_like, optional
array of integers or booleans to slice trajectory; cannot be
combined with `start`, `stop`, `step`
array of integers or booleans to slice trajectory; `frames` can
only be used *instead* of `start`, `stop`, and `step`. Setting
*both* `frames` and at least one of `start`, `stop`, `step` to a
non-default value will raise a :exc:`ValueError`.

.. versionadded:: 2.2.0

Expand Down Expand Up @@ -389,8 +460,13 @@ def _conclude(self):
"""
pass # pylint: disable=unnecessary-pass

def run(self, start=None, stop=None, step=None, frames=None,
verbose=None, *, progressbar_kwargs=None):
def run(self,
start: Optional[int] = None,
stop: Optional[int] = None,
step: Optional[int] = None,
frames: Optional[int] = None,
verbose: Optional[bool] = None,
progressbar_kwargs: Optional[dict] = None) -> "AnalysisBase":
"""Perform the calculation

Parameters
Expand Down Expand Up @@ -426,34 +502,128 @@ def run(self, start=None, stop=None, step=None, frames=None,
Add `progressbar_kwargs` parameter,
allowing to modify description, position etc of tqdm progressbars
"""
logger.info("Choosing frames to analyze")
# if verbose unchanged, use class default
verbose = getattr(self, '_verbose',
False) if verbose is None else verbose

self._setup_frames(self._trajectory, start=start, stop=stop,
step=step, frames=frames)
logger.info("Starting preparation")
self._prepare()
logger.info("Starting analysis loop over %d trajectory frames",
self.n_frames)
if progressbar_kwargs is None:
progressbar_kwargs = {}

for i, ts in enumerate(ProgressBar(
self._sliced_trajectory,
verbose=verbose,
**progressbar_kwargs)):
self._frame_index = i
self._ts = ts
self.frames[i] = ts.frame
self.times[i] = ts.time
self._single_frame()
logger.info("Finishing up")
self._conclude()
_run(analysis_instances=(self,),
start=start,
stop=stop,
step=step,
frames=frames,
verbose=verbose,
progressbar_kwargs=progressbar_kwargs)

return self


class AnalysisCollection:
"""
Class for running a collection of analysis classes on a single trajectory.

Running a collection of analyses with ``AnalysisCollection`` can result in
a speedup compared to running the individual analyses since the trajectory
loop ins only performed once.

The class assumes that each analysis is a child of
:class:`MDAnalysis.analysis.base.AnalysisBase`. Additionally, the
trajectory of all `analysis_instances` must be the same.

By default, it is ensured that all analysis instances use the
*same original* timestep and not an altered one from a previous analysis
object.

Parameters
----------
*analysis_instances : tuple
List of analysis classes to run on the same trajectory.

Raises
------
AttributeError
If all the provided ``analysis_instances`` do not work on the same
trajectory.
AttributeError
If an ``analysis_object`` is not a child of
:class:`MDAnalysis.analysis.base.AnalysisBase`.

Example
-------
.. code-block:: python

import MDAnalysis as mda
from MDAnalysis.analysis.rdf import InterRDF
from MDAnalysis.analysis.base import AnalysisCollection
from MDAnalysisTests.datafiles import TPR, XTC

u = mda.Universe(TPR, XTC)

# Select atoms
ag_O = u.select_atoms("name O")
ag_H = u.select_atoms("name H")

# Create individual analysis instances
rdf_OO = InterRDF(ag_O, ag_O)
rdf_OH = InterRDF(ag_H, ag_H)

# Create a collection for common trajectory
collection = AnalysisCollection(rdf_OO, rdf_OH)

# Run the collected analysis
collection.run(start=0, stop=100, step=10)

# Results are stored in the individual instances
print(rdf_OO.results)
print(rdf_OH.results)


.. versionadded:: 2.8.0

"""

def __init__(self, *analysis_instances):
for analysis_object in analysis_instances:
if (
analysis_instances[0]._trajectory
!= analysis_object._trajectory
):
raise ValueError("`analysis_instances` do not have the same "
"trajectory.")
if not isinstance(analysis_object, AnalysisBase):
raise AttributeError(f"Analysis object {analysis_object} is "
"not a child of `AnalysisBase`.")

self._analysis_instances = analysis_instances

def run(self,
start: Optional[int] = None,
stop: Optional[int] = None,
step: Optional[int] = None,
frames: Optional[int] = None,
verbose: Optional[bool] = None,
progressbar_kwargs: Optional[dict] = None) -> None:
"""Perform the calculation

Parameters
----------
start : int, optional
start frame of analysis
stop : int, optional
stop frame of analysis
step : int, optional
number of frames to skip between each analysed frame
frames : array_like, optional
array of integers or booleans to slice trajectory; `frames` can
only be used *instead* of `start`, `stop`, and `step`. Setting
*both* `frames` and at least one of `start`, `stop`, `step` to a
non-default value will raise a :exc:`ValueError`.
verbose : bool, optional
Turn on verbosity
"""
_run(analysis_instances=self._analysis_instances,
start=start,
stop=stop,
step=step,
frames=frames,
verbose=verbose,
progressbar_kwargs=progressbar_kwargs)

class AnalysisFromFunction(AnalysisBase):
r"""Create an :class:`AnalysisBase` from a function working on AtomGroups

Expand Down Expand Up @@ -526,7 +696,7 @@ def __init__(self, function, trajectory=None, *args, **kwargs):

self.kwargs = kwargs

super(AnalysisFromFunction, self).__init__(trajectory)
super().__init__(trajectory)

def _prepare(self):
self.results.timeseries = []
Expand Down Expand Up @@ -594,8 +764,7 @@ def RotationMatrix(mobile, ref):

class WrapperClass(AnalysisFromFunction):
def __init__(self, trajectory=None, *args, **kwargs):
super(WrapperClass, self).__init__(function, trajectory,
*args, **kwargs)
super().__init__(function, trajectory, *args, **kwargs)

return WrapperClass

Expand Down
66 changes: 65 additions & 1 deletion testsuite/MDAnalysisTests/analysis/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import MDAnalysis as mda
from MDAnalysis.analysis import base
from MDAnalysis.analysis.rdf import InterRDF

from MDAnalysisTests.datafiles import PSF, DCD, TPR, XTC
from MDAnalysisTests.util import no_deprecated_call
Expand Down Expand Up @@ -147,6 +148,68 @@ def test_different_instances(self, results):
assert new_results.data is not results.data


class TestAnalysisCollection:
@pytest.fixture
def universe(self):
return mda.Universe(TPR, XTC)

def test_run(self, universe):
ag_O = universe.select_atoms("name O")
ag_H = universe.select_atoms("name H")

rdf_OO = InterRDF(ag_O, ag_O)
rdf_OH = InterRDF(ag_O, ag_H)

collection = base.AnalysisCollection(rdf_OO, rdf_OH)
collection.run(start=0, stop=100, step=10)

assert rdf_OO.results is not None
assert rdf_OH.results is not None

def test_trajectory_manipulation(self, universe):
"""Test that the timestep is the same for each analysis class."""
class CustomAnalysis(base.AnalysisBase):
"""Custom class that is shifting positions in every step by 10."""

def __init__(self, trajectory):
self._trajectory = trajectory

def _prepare(self):
pass

def _single_frame(self):
self._ts.positions += 10
self.ref_pos = self._ts.positions.copy()[0, 0]

ana_1 = CustomAnalysis(universe.trajectory)
ana_2 = CustomAnalysis(universe.trajectory)

collection = base.AnalysisCollection(ana_1, ana_2)
collection.run(frames=[0])

assert ana_2.ref_pos == ana_1.ref_pos

def test_inconsistent_trajectory(self, universe):
v = mda.Universe(TPR, XTC)

match = "`analysis_instances` do not have the same trajectory."
with pytest.raises(ValueError, match=match):
base.AnalysisCollection(
InterRDF(
universe.atoms, universe.atoms), InterRDF(v.atoms, v.atoms)
)

def test_no_base_child(self, universe):
class CustomAnalysis:
def __init__(self, trajectory):
self._trajectory = trajectory

match = "not a child of `AnalysisBase`"
# collection for common trajectory loop with inconsistent trajectory
with pytest.raises(AttributeError, match=match):
base.AnalysisCollection(CustomAnalysis(universe.trajectory))


class FrameAnalysis(base.AnalysisBase):
"""Just grabs frame numbers of frames it goes over"""

Expand Down Expand Up @@ -252,7 +315,8 @@ def test_frames_times(u_xtc):
assert an.n_frames == len(frames)
assert_equal(an.found_frames, frames)
assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
assert_allclose(an.times, frames*100, rtol=0, atol=1.5e-4, err_msg=TIMES_ERR)
assert_allclose(
an.times, frames * 100, rtol=0, atol=1.5e-4, err_msg=TIMES_ERR)


def test_verbose(u):
Expand Down
Loading