Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Slow finufft on linux #137

Open
bwheelz36 opened this issue Aug 19, 2022 · 6 comments
Open

Slow finufft on linux #137

bwheelz36 opened this issue Aug 19, 2022 · 6 comments

Comments

@bwheelz36
Copy link
Member

bwheelz36 commented Aug 19, 2022

The steps below allow someone to reproduce the issue described here.

set up environment and sample data

git clone https://github.com/ACRF-Image-X-Institute/MRI_DistortionQA.git
cd MRI_DistortionQA/
git checkout distortion_correction
python3 -m venv venv
source venv/bin/activate
pip3 install -U pip
pip3 install -U setuptools
pip install -r dev_requirements.txt

# Replace the default finufft with one compiled from source.
# -y skips pip's interactive "Proceed (Y/n)?" prompt; without it, pasting this
# script would feed the following lines to the prompt instead of the shell.
pip uninstall -y finufft
git clone https://github.com/flatironinstitute/finufft.git
cd finufft/
make test
make python
cd ..

# Get the sample data (300 Mb).
wget https://cloudstor.aarnet.edu.au/plus/s/Wm9vndV47u941JU/download
unzip download
rm download

Run example

Copy the code below into a new file in the MRI_DistortionQA root directory:

from pathlib import Path
from MRI_DistortionQA.MarkerAnalysis import MarkerVolume
from MRI_DistortionQA.MarkerAnalysis import MatchedMarkerVolumes
from MRI_DistortionQA.FieldCalculation import ConvertMatchedMarkersToBz
from MRI_DistortionQA import calculate_harmonics
import numpy as np
from MRI_DistortionQA.K_SpaceCorrector import KspaceDistortionCorrector

# Inputs: the distorted MR series and the CT ground-truth marker centroids.
dis_data_loc = Path(r'MRI_distortion_QA_sample_data/MR/04 gre_trans_AP_330')
gt_data_loc = Path(r'MRI_distortion_QA_sample_data/CT/slicer_centroids.mrk.json')

# Extract markers from both volumes (ground truth first, then distorted).
gt_volume = MarkerVolume(gt_data_loc, r_max=300)
dis_volume = MarkerVolume(
    dis_data_loc,
    n_markers_expected=336,
    iterative_segmentation=True,
)

# Pair the distorted markers with their ground-truth counterparts.
matched_volume = MatchedMarkerVolumes(gt_volume, dis_volume, n_refernce_markers=11)

# Turn the matched marker positions into field data.
B_fields = ConvertMatchedMarkersToBz(matched_volume.MatchedCentroids, dis_volume.dicom_data)

# Harmonic fit. The three gradient terms are scaled by 1/gradient_strength so
# the gradient harmonics are normalised to 1 mT/m; the fourth (B0) term is
# left unscaled. To disable normalisation use [1, 1, 1, 1] instead.
gradient_strength = np.array(dis_volume.dicom_data['gradient_strength'])
normalisation_factor = [1 / strength for strength in gradient_strength[:3]] + [1]
G_x_Harmonics, G_y_Harmonics, G_z_Harmonics, B0_Harmonics = calculate_harmonics(
    B_fields.MagneticFields,
    n_order=8,
    norm=normalisation_factor,
)

# Correct the input images with the fitted gradient harmonics
# (through-plane correction disabled).
GDC = KspaceDistortionCorrector(
    ImageDirectory=dis_data_loc.resolve(),
    Gx_Harmonics=G_x_Harmonics.harmonics,
    Gy_Harmonics=G_y_Harmonics.harmonics,
    Gz_Harmonics=G_z_Harmonics.harmonics,
    ImExtension='dcm',
    dicom_data=dis_volume.dicom_data,
    correct_through_plane=False,
)
GDC.correct_all_images()
@bwheelz36
Copy link
Member Author

bwheelz36 commented Oct 24, 2022

I am working on a standalone script to demonstrate this issue.
The following takes 1.7 s on Linux, which is similar to the results I see for the real application (above) after manually installing FFTW:

"""
Demonstrate the behavior of finufft discussed here:
https://github.com/flatironinstitute/finufft/issues/235
https://github.com/ACRF-Image-X-Institute/mri_distortion_toolkit/issues/137

This script demonstrates this behavior in a self contained way.

"""
from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter


def fiNufft_Ax(x):
    """
    Flatiron Institute NUFFT forward operator: return A*x.

    Equivalent to the 'notranspose' option in Shanshan's code. Executes the
    module-level type-3 plan ``Nufft_Ax_Plan``, whose nonuniform source points
    are (xj, yj) and whose target points are (sk, tk).

    Parameters
    ----------
    x : array_like
        Input vector. It is cast to complex128 before execution, matching the
        original code's explicit cast.

    Returns
    -------
    numpy.ndarray
        The plan's output, flattened to 1-D.
    """
    # Compare/convert by value rather than with `is` on dtype objects; asarray
    # is a no-op when x is already complex128.
    x = np.asarray(x, dtype=np.complex128)
    # Former direct-call equivalent:
    # y = nufft2d3(xj, yj, x, sk, tk, eps=1e-06, isign=-1)
    y = Nufft_Ax_Plan.execute(x, None)
    return y.flatten()


def fiNufft_Atb(x):
    """
    Flatiron Institute NUFFT adjoint operator: return A'*x.

    Equivalent to the 'transpose' option in Shanshan's code. Used as the
    ``rmatvec`` of a scipy.sparse.linalg.LinearOperator so the NUFFT can be
    passed to the lsqr solver. For an explanation, see:
    https://stackoverflow.com/questions/48621407/python-equivalent-of-matlabs-lsqr-with-first-argument-a-function

    Executes the module-level type-3 plan ``Nufft_Atb_Plan``, whose source and
    target points are swapped relative to the forward plan.

    Parameters
    ----------
    x : array_like
        Input vector. It is cast to complex128 before execution, consistent
        with ``fiNufft_Ax``.

    Returns
    -------
    numpy.ndarray
        The plan's output, flattened to 1-D.
    """
    # Same cast as the forward operator, so a real-valued vector does not
    # reach a plan that expects complex128 data.
    x = np.asarray(x, dtype=np.complex128)
    y = Nufft_Atb_Plan.execute(x, None)
    return y.flatten()


# set up random image
StartingImage = np.random.rand(148, 148)
StartingImage[StartingImage > 0.5] = 100
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = np.reshape(k_space, StartingImage.shape[0] * StartingImage.shape[1])

# Set up indices: integer grid coordinates centred on zero. They are sized
# from the image so the grid and the image cannot drift apart.
x_lin_size, y_lin_size = StartingImage.shape
xn_lin = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
yn_lin = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
[xn_lin, yn_lin] = np.meshgrid(xn_lin, yn_lin, indexing='ij')
xn_lin = xn_lin.flatten()
yn_lin = yn_lin.flatten()

# The following is just a very hacky way to get some distorted indices
# which somewhat resemble the real case.
xj = (xn_lin + np.sin(xn_lin)*5)*10
yj = (yn_lin + np.sin(yn_lin)*5)*10
sk = xn_lin / x_lin_size
tk = yn_lin / y_lin_size

# 2-D type-3 plans (n_trans=1, eps=1e-6). The forward plan maps sources
# (xj, yj) to targets (sk, tk) with isign=-1. The adjoint plan swaps sources
# and targets and flips the sign.
Nufft_Ax_Plan = Plan(3, 2, 1, 1e-06, -1)
Nufft_Ax_Plan.setpts(xj, yj, None, sk, tk)
Nufft_Atb_Plan = Plan(3, 2, 1, 1e-06, 1)
Nufft_Atb_Plan.setpts(sk, tk, None, xj, yj)

A = LinearOperator((fk1.shape[0], fk1.shape[0]), matvec=fiNufft_Ax, rmatvec=fiNufft_Atb)
StartingImage = StartingImage.flatten().astype(complex)
maxit = 20
# Named run_times (not `time`) to avoid shadowing the stdlib module name.
run_times = []
for i in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=maxit, x0=StartingImage)
    run_times.append(perf_counter() - _start_time)
print(f'run time: {np.mean(run_times): 1.2f} \u00B1 {np.std(run_times): 1.2f}s')

@bwheelz36
Copy link
Member Author

bwheelz36 commented Oct 24, 2022

| Case | Run time (s) |
|---|---|
| Linux: default fftw, pip installed finufft | 3.4 ± 0.33 |
| Linux: built fftw, pip installed finufft | 1.23 ± 0.16 |
| Linux: built fftw, built finufft | 1.44 ± 0.22 |
| Windows | 0.45 ± 0.02 |

@bwheelz36
Copy link
Member Author

Example that does not use finufft (it uses pynufft instead):

"""
same example but without any finufft dependency

"""
# from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter
from pynufft import NUFFT, helper


def PyNufft_Ax(x):
    """
    pynufft forward operator for ``LinearOperator.matvec``: return A*x.

    Mirrors ``fiNufft_Ax`` in the finufft version of this example.

    Parameters
    ----------
    x : numpy.ndarray
        Flattened image vector supplied by lsqr. It is reshaped to the image
        size ``Nd`` (the module-level value passed to ``Nufft_Ax_Plan.plan``)
        before transforming.

    Returns
    -------
    numpy.ndarray
        The nonuniform k-space samples, flattened to 1-D.
    """
    # Use forward(), not solve(): solve() inverts A (k-space -> image) with an
    # iterative CG solver on every call. That is not A*x, and it does not pair
    # with the adjoint() applied in PyNufft_Atb.
    y = Nufft_Ax_Plan.forward(np.reshape(x, Nd))
    return np.ravel(y)

def PyNufft_Atb(x):
    """
    pynufft adjoint operator for ``LinearOperator.rmatvec``: return A'*x.

    Applies ``adjoint`` of the module-level plan ``Nufft_Ax_Plan`` to the
    k-space vector ``x`` and returns the result unchanged.
    """
    return Nufft_Ax_Plan.adjoint(x)


# set up random image
StartingImage = np.random.rand(148, 148)
# Image dimensions (rows, columns), read separately so that non-square
# images are also handled correctly.
Nrows, Ncolumns = StartingImage.shape
StartingImage[StartingImage > 0.5] = 100
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = np.reshape(k_space, Nrows * Ncolumns)


# Set up indices: integer grid coordinates centred on zero, sized to the image.
x_lin_size, y_lin_size = Nrows, Ncolumns
xn_lin = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
yn_lin = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
[xn_lin, yn_lin] = np.meshgrid(xn_lin, yn_lin, indexing='ij')
xn_lin = xn_lin.flatten()
yn_lin = yn_lin.flatten()

# The following is just a very hacky way to get some distorted indices
# which somewhat resemble the real case.
xj = (xn_lin + np.sin(xn_lin)*5)*10
yj = (yn_lin + np.sin(yn_lin)*5)*10
# sk/tk are not used by pynufft; they are kept to mirror the finufft example.
sk = xn_lin / x_lin_size
tk = yn_lin / y_lin_size
yn_dis = yj/(2*np.pi)
xn_dis = xj/(2*np.pi)

# instantiate plan
# Nufft_Ax_Plan = NUFFT(helper.device_list()[0])
Nufft_Ax_Plan = NUFFT()

# Scale the distorted positions to radians for pynufft's sample coordinates.
# NOTE(review): these were labelled [-pi, pi], but |xj| reaches ~750, so the
# values here span roughly +/-5 rad. pynufft expects coordinates within
# [-pi, pi] -- confirm this is intended.
Kx_dis_pytorch = np.reshape(xn_dis, [Nrows, Ncolumns]) / Nrows * 2 * np.pi
Ky_dis_pytorch = np.reshape(yn_dis, [Nrows, Ncolumns]) / Ncolumns * 2 * np.pi

# (M, 2) array of sample coordinates in row-major order. This is the same
# ordering the element-by-element i/j copy loop produced, without the
# Python-level loop.
k_xy_dis = np.column_stack([Kx_dis_pytorch.ravel(), Ky_dis_pytorch.ravel()])

Nd = StartingImage.shape  # image size
Kd = k_space.shape  # kspace size (no oversampling: Kd == Nd)
Jd = (3, 3)  # interpolation size
Nufft_Ax_Plan.plan(k_xy_dis, Nd, Kd, Jd)
A = LinearOperator((fk1.shape[0], fk1.shape[0]), matvec=PyNufft_Ax, rmatvec=PyNufft_Atb)
# Named run_times (not `time`) to avoid shadowing the stdlib module name.
run_times = []
for i in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=20, x0=None)
    run_times.append(perf_counter() - _start_time)
print(f'run time: {np.mean(run_times): 1.2f} \u00B1 {np.std(run_times): 1.2f}s')

@bwheelz36
Copy link
Member Author

linux_profile.txt

@bwheelz36
Copy link
Member Author

bwheelz36 commented Oct 31, 2022

windows: 2.18 ± 0.18s
linux: 4.23 ± 1.36s

@bwheelz36
Copy link
Member Author

def dummy_Ax(x):
    """
    Stand-in forward operator for timing experiments.

    Sleeps for 0.1 s to simulate work, then returns the element-wise square
    of ``x`` reshaped to 148 x 148.

    Parameters
    ----------
    x : numpy.ndarray
        Array with 148 * 148 elements.

    Returns
    -------
    numpy.ndarray
        A (148, 148) array of squared values.
    """
    # Local import: `sleep` is not otherwise in scope for this snippet, so the
    # original call raised NameError.
    from time import sleep

    sleep(.1)
    y = x.reshape(148, 148) ** 2
    return y

def dummmy_Atb(x):
    """
    Stand-in adjoint operator for timing experiments.

    Sleeps for 0.1 s to simulate work, then returns the element-wise square
    root of ``x`` reshaped to 148 x 148.

    Parameters
    ----------
    x : numpy.ndarray
        Array with 148 * 148 elements.

    Returns
    -------
    numpy.ndarray
        A (148, 148) array of square roots.
    """
    # Local import: `sleep` is not otherwise in scope for this snippet, so the
    # original call raised NameError.
    from time import sleep

    sleep(.1)
    y = np.sqrt(x.reshape(148, 148))
    return y

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant