-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
jupyter_utils.py
372 lines (302 loc) · 15.3 KB
/
jupyter_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This set of utility functions is meant to make using Jupyter notebooks easier with MONAI. Plotting functions using
Matplotlib produce common plots for metrics and images.
"""
from __future__ import annotations
import copy
from collections.abc import Callable, Mapping
from enum import Enum
from threading import RLock, Thread
from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from monai.utils import IgniteInfo
from monai.utils.module import min_version, optional_import
try:
import matplotlib.pyplot as plt
has_matplotlib = True
except ImportError:
has_matplotlib = False
if TYPE_CHECKING:
from ignite.engine import Engine, Events
else:
Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")
LOSS_NAME = "loss"
def plot_metric_graph(
    ax: plt.Axes,
    title: str,
    graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]],
    yscale: str = "log",
    avg_keys: tuple[str] = (LOSS_NAME,),
    window_fraction: int = 20,
) -> None:
    """
    Draw every named series of `graphmap` onto `ax`, adding a running-average curve for the keys in `avg_keys`.

    Each series is either a plain sequence of values (x positions are then 0, 1, 2, ...) or a sequence of
    (index, value) pairs, which is how MetricLogger objects store their data. Empty series are skipped.

    Args:
        ax: Axes object to plot into
        title: graph title
        graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs
        yscale: scale for y-axis compatible with `Axes.set_yscale`
        avg_keys: tuple of keys in `graphmap` to provide running average plots for
        window_fraction: the running average window is the series length integer-divided by this value; no
            average is drawn for series that are not longer than `window_fraction`
    """
    from matplotlib.ticker import MaxNLocator

    for name, series in graphmap.items():
        count = len(series)
        if count == 0:
            continue  # nothing to draw for this key

        if isinstance(series[0], (tuple, list)):
            # entries are (x, y) pairs: split them into an x sequence and a y sequence
            xs, ys = zip(*series)
        else:
            # bare values: synthesize consecutive x positions
            xs = tuple(range(count))
            ys = tuple(series)

        ax.plot(xs, ys, label=f"{name} = {ys[-1]:.5g}")

        if name not in avg_keys or count <= window_fraction:
            continue  # running average not requested, or too few samples for a window

        width = count // window_fraction
        # left-pad with the first value so the "valid" convolution yields one output per input sample
        padded = (ys[0],) * (width - 1) + ys
        smoothed = np.convolve(padded, np.ones((width,)) / width, mode="valid")
        ax.plot(xs, smoothed, label=f"{name} Avg = {smoothed[-1]:.5g}")

    ax.set_title(title)
    ax.set_yscale(yscale)
    ax.axis("on")
    ax.legend(bbox_to_anchor=(1, 1), loc=1, borderaxespad=0.0)
    ax.grid(True, "both", "both")
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
def plot_metric_images(
    fig: plt.Figure,
    title: str,
    graphmap: Mapping[str, list[float] | tuple[list[float], list[float]]],
    imagemap: dict[str, np.ndarray],
    yscale: str = "log",
    avg_keys: tuple[str] = (LOSS_NAME,),
    window_fraction: int = 20,
) -> list:
    """
    Plot metric graph data with images below into figure `fig`. The intended use is for the graph data to be
    metrics from a training run and the images to be the batch and output from the last iteration. This uses
    `plot_metric_graph` to plot the metric graph.

    The figure is laid out on a grid with 4 rows and one column per image (at least one column): the graph spans
    the whole first row and each image occupies rows 1-2 of its own column. An image is drawn in color only when
    it is a 3D array whose first dimension has size 3 (RGB) or 4 (RGBA), in which case it is transposed from
    channel-first to channel-last for `imshow`. Every other image is squeezed and drawn with a gray colormap.

    Args:
        fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing
        title: graph title
        graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs
        imagemap: dictionary of named images to show with metric plot
        yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale`
        avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for
        window_fraction: for metric plot, what fraction of the graph value length to use as the running average window

    Returns:
        list of Axes objects for graph followed by images
    """
    gridshape = (4, max(1, len(imagemap)))

    graph = plt.subplot2grid(gridshape, (0, 0), colspan=gridshape[1], fig=fig)
    plot_metric_graph(graph, title, graphmap, yscale, avg_keys, window_fraction)

    axes = [graph]
    for col, (name, image) in enumerate(imagemap.items()):
        im = plt.subplot2grid(gridshape, (1, col), rowspan=2, fig=fig)

        # Only a genuine channel-first color image may be transposed: checking the rank avoids a transpose error
        # on 2D images that happen to have 3 rows, and accepting 4 channels lets RGBA images render in color.
        if image.ndim == 3 and image.shape[0] in (3, 4):
            im.imshow(image.transpose([1, 2, 0]))
        else:
            im.imshow(np.squeeze(image), cmap="gray")

        im.set_title(f"{name}\n{image.min():.3g} -> {image.max():.3g}")
        im.axis("off")
        axes.append(im)

    return axes
def tensor_to_images(name: str, tensor: torch.Tensor) -> np.ndarray | None:
    """
    Return a tuple of images derived from the given tensor. The `name` value indicates which key from the
    output or batch value the tensor was stored as, or is "Batch" or "Output" if these were single tensors
    instead of dictionaries; it is not used by this implementation. Returns a tuple of 2D images of shape HW,
    or 3D images of shape CHW where C is color channels RGB or RGBA. This allows multiple images to be created
    from a single tensor, ie. to show each channel separately.

    A 3-dimensional tensor is returned whole as a Numpy array. For a 4-dimensional tensor the middle index of
    dimension 1 is selected first, giving a 3-dimensional array. In both cases the last two dimensions must each
    be larger than 2, otherwise (or for any other rank) None is returned.
    """
    rank = tensor.ndim
    if rank not in (3, 4):
        return None

    # the trailing two dimensions are treated as the image plane and must not be degenerate
    if tensor.shape[-2] <= 2 or tensor.shape[-1] <= 2:
        return None

    if rank == 4:
        # reduce to 3 dimensions by taking the middle slice along dimension 1
        tensor = tensor[:, tensor.shape[1] // 2]

    return tensor.detach().cpu().numpy()  # type: ignore[no-any-return]
def plot_engine_status(
    engine: Engine,
    logger: Any,
    title: str = "Training Log",
    yscale: str = "log",
    avg_keys: tuple[str] = (LOSS_NAME,),
    window_fraction: int = 20,
    image_fn: Callable[[str, torch.Tensor], Any] | None = tensor_to_images,
    fig: plt.Figure | None = None,
    selected_inst: int = 0,
) -> tuple[plt.Figure, list]:
    """
    Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics
    taken from the logger, and images taken from the `output` and `batch` members of `engine.state`. The images are
    converted to Numpy arrays suitable for input to `Axes.imshow` using `image_fn`, if this is None then no image
    plotting is done.

    Images are collected as follows for the batch and then the output:
      * a list (decollated batch): element `selected_inst` is taken, and each of its tensors with at least 3
        dimensions is given a leading batch dimension and handled like a dictionary entry
      * a dictionary: for each tensor with at least 4 dimensions the selected instance is passed to `image_fn`,
        and each returned image is stored under the key `"{key}_{index}"`
      * a single tensor: the whole tensor is passed to `image_fn` with the name "Batch" or "Output", and each
        returned image is stored under the key `"{label}_{index}"`; `selected_inst` is not applied in this case

    Args:
        engine: Engine to extract images from
        logger: MetricLogger to extract loss and metric data from
        title: graph title
        yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale`
        avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for
        window_fraction: for metric plot, what fraction of the graph value length to use as the running average window
        image_fn: callable converting tensors keyed to a name in the Engine to a tuple of images to plot
        fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing
        selected_inst: index of the instance to show in the image plot

    Returns:
        Figure object (or `fig` if given), list of Axes objects for graph and images
    """
    if fig is not None:
        fig.clf()
    else:
        fig = plt.Figure(figsize=(20, 10), tight_layout=True, facecolor="white")

    graphmap: dict[str, list[float]] = {LOSS_NAME: logger.loss}
    graphmap.update(logger.metrics)

    imagemap: dict = {}
    if image_fn is not None and engine.state is not None and engine.state.batch is not None:
        # pair each source with its label explicitly so the output is labelled correctly even when it is the
        # very same object as the batch
        for label, src in (("Batch", engine.state.batch), ("Output", engine.state.output)):
            batch_selected_inst = selected_inst  # selected batch index, set to 0 when src is decollated

            # if the src object is a list of elements, ie. a decollated batch, select an element and keep it as
            # a dictionary of tensors with a batch dimension added
            if isinstance(src, list):
                selected_dict = src[selected_inst]  # select this element
                batch_selected_inst = 0  # set the selection to be the single index in the batch dimension
                # store each tensor that is interpretable as an image with an added batch dimension
                src = {k: v[None] for k, v in selected_dict.items() if isinstance(v, torch.Tensor) and v.ndim >= 3}

            # images will be generated from the batch item selected above only, or from the single item given as `src`
            if isinstance(src, dict):
                for k, v in src.items():
                    if isinstance(v, torch.Tensor) and v.ndim >= 4:
                        image = image_fn(k, v[batch_selected_inst])

                        # if we have images add each one separately to the map
                        if image is not None:
                            for i, im in enumerate(image):
                                imagemap[f"{k}_{i}"] = im
            elif isinstance(src, torch.Tensor):
                image = image_fn(label, src)
                if image is not None:
                    # add each returned image under its own index, as is done for dictionary entries; the index
                    # must come from this loop rather than from a variable left over from an earlier one
                    for i, im in enumerate(image):
                        imagemap[f"{label}_{i}"] = im

    axes = plot_metric_images(fig, title, graphmap, imagemap, yscale, avg_keys, window_fraction)

    if logger.loss:
        # loss entries are indexed as (index, value) pairs here
        axes[0].axhline(logger.loss[-1][1], c="k", ls=":")  # draw dotted horizontal line at last loss value

    return fig, axes
def _get_loss_from_output(
output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
) -> torch.Tensor:
"""Returns a single value from the network output, which is a dict or tensor."""
def _get_loss(data: torch.Tensor | dict[str, torch.Tensor]) -> torch.Tensor:
if isinstance(data, dict):
return data["loss"]
return data
if isinstance(output, list):
return _get_loss(output[0])
return _get_loss(output)
class StatusMembers(Enum):
    """
    Keys of the fixed entries in the status dictionary. Entries for named metric values may be present in the
    dictionary in addition to these.
    """

    # whether the engine thread is running or stopped
    STATUS = "Status"
    # epoch progress
    EPOCHS = "Epochs"
    # iteration progress
    ITERS = "Iters"
    # most recent loss value
    LOSS = "Loss"
class ThreadContainer(Thread):
    """
    Contains a running `Engine` object within a separate thread from main thread in a Jupyter notebook. This
    allows an engine to begin a run in the background and allow the starting notebook cell to complete. A
    user can thus start a run and then navigate away from the notebook without concern for losing connection
    with the running cell. All output is acquired through methods which synchronize with the running engine
    using an internal `lock` member, acquiring this lock allows the engine to be inspected while it's prevented
    from starting the next iteration.

    Args:
        engine: wrapped `Engine` object, when the container is started its `run` method is called
        loss_transform: callable to convert an output dict into a single numeric value
        metric_transform: callable to convert a named metric value into a single numeric value
        status_format: format string for status key-value pairs.
    """

    def __init__(
        self,
        engine: Engine,
        loss_transform: Callable = _get_loss_from_output,
        metric_transform: Callable = lambda name, value: value,
        status_format: str = "{}: {:.4}",
    ):
        super().__init__()
        # reentrant so that `plot_status` can call `status` while already holding the lock
        self.lock = RLock()
        self.engine = engine
        self._status_dict: dict[str, Any] = {}
        self.loss_transform = loss_transform
        self.metric_transform = metric_transform
        self.fig: plt.Figure | None = None
        self.status_format = status_format

        # refresh the status dictionary every time the engine completes an iteration
        self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._update_status)

    def run(self):
        """Calls the `run` method of the wrapped engine."""
        self.engine.run()

    def stop(self):
        """Stop the engine and join the thread."""
        self.engine.terminate()
        self.join()

    def _update_status(self):
        """Called as an event, updates the internal status dict at the end of iterations."""
        with self.lock:
            state = self.engine.state
            # defaults used when the engine has no state yet
            stats: dict[str, Any] = {
                StatusMembers.EPOCHS.value: 0,
                StatusMembers.ITERS.value: 0,
                StatusMembers.LOSS.value: float("nan"),
            }

            if state is not None:
                # show "current/max" when the maximum is known, otherwise just the current count
                if state.max_epochs is not None and state.max_epochs >= 1:
                    epoch = f"{state.epoch}/{state.max_epochs}"
                else:
                    epoch = str(state.epoch)

                if state.epoch_length is not None:
                    iters = f"{state.iteration % state.epoch_length}/{state.epoch_length}"
                else:
                    iters = str(state.iteration)

                stats[StatusMembers.EPOCHS.value] = epoch
                stats[StatusMembers.ITERS.value] = iters
                stats[StatusMembers.LOSS.value] = self.loss_transform(state.output)

                metrics = state.metrics or {}
                for m, v in metrics.items():
                    v = self.metric_transform(m, v)
                    if v is not None:
                        # record the current value under the metric's name; `stats` holds no entry for the
                        # metric beforehand so the value must be assigned rather than appended
                        stats[m] = v

            self._status_dict.update(stats)

    @property
    def status_dict(self) -> dict[str, str]:
        """A dictionary containing status information, current loss, and current metric values."""
        with self.lock:
            stats = {StatusMembers.STATUS.value: "Running" if self.is_alive() else "Stopped"}
            stats.update(self._status_dict)
            return stats

    def status(self) -> str:
        """Returns a status string for the current state of the engine."""
        # work on a copy since entries are popped below
        stats = copy.deepcopy(self.status_dict)

        # the running state and iteration count always lead the message
        msgs = [stats.pop(StatusMembers.STATUS.value), "Iters: " + str(stats.pop(StatusMembers.ITERS.value, 0))]

        for key, val in stats.items():
            # only float values go through `status_format`, everything else is printed as is
            if isinstance(val, float):
                msg = self.status_format.format(key, val)
            else:
                msg = f"{key}: {val}"

            msgs.append(msg)

        return ", ".join(msgs)

    def plot_status(self, logger: Any, plot_func: Callable = plot_engine_status) -> plt.Figure | None:
        """
        Generate a plot of the current status of the contained engine whose loss and metrics were tracked by `logger`.
        The function `plot_func` must accept arguments `title`, `engine`, `logger`, and `fig` which are the plot title,
        `self.engine`, `logger`, and `self.fig` respectively. The return value must be a figure object (stored in
        `self.fig`) and a list of Axes objects for the plots in the figure. Only the figure is returned by this method,
        which holds the internal lock during the plot generation.
        """
        with self.lock:
            self.fig, _ = plot_func(title=self.status(), engine=self.engine, logger=logger, fig=self.fig)
            return self.fig