-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Slow finufft on Linux #137
Comments
Working on a standalone script to demonstrate this issue. """
Demonstrate the behavior of finufft discussed here:
https://github.com/flatironinstitute/finufft/issues/235
https://github.com/ACRF-Image-X-Institute/mri_distortion_toolkit/issues/137
This script demonstrates this behavior in a self contained way.
"""
from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter
def fiNufft_Ax(x, plan=None):
    """
    Forward non-uniform FFT (Flatiron Institute finufft): return A*x.

    Equivalent to the 'notranspose' option in Shanshan's code. In the script
    below, the default plan is a type-3 plan whose source points xj, yj are
    the non-uniform (distorted) encoding positions and whose target points
    are sk, tk.

    Parameters
    ----------
    x : array_like
        Input vector. It is converted to complex128; no copy is made if it
        already has that dtype.
    plan : optional
        Object with an ``execute(data, out)`` method. Defaults to the global
        ``Nufft_Ax_Plan``, so this function can be passed directly as a
        ``LinearOperator`` ``matvec`` (which calls it with one argument).

    Returns
    -------
    numpy.ndarray
        The transform result, flattened to 1-D.
    """
    if plan is None:
        plan = Nufft_Ax_Plan
    # Compare dtypes by value (not identity); asarray only copies when a
    # conversion is actually needed.
    x = np.asarray(x, dtype=np.complex128)
    y = plan.execute(x, None)
    return y.flatten()
def fiNufft_Atb(x, plan=None):
    """
    Adjoint non-uniform FFT (Flatiron Institute finufft): return A'*x.

    Equivalent to the 'transpose' option in Shanshan's code. Together with
    ``fiNufft_Ax`` this defines the NUFFT as a
    ``scipy.sparse.linalg.LinearOperator`` that the lsqr algorithm can use;
    see:
    https://stackoverflow.com/questions/48621407/python-equivalent-of-matlabs-lsqr-with-first-argument-a-function

    Parameters
    ----------
    x : array_like
        Input vector. It is converted to complex128, matching ``fiNufft_Ax``;
        no copy is made if it already has that dtype.
    plan : optional
        Object with an ``execute(data, out)`` method. Defaults to the global
        ``Nufft_Atb_Plan``, so this function can be passed directly as a
        ``LinearOperator`` ``rmatvec``.

    Returns
    -------
    numpy.ndarray
        The transform result, flattened to 1-D.
    """
    if plan is None:
        plan = Nufft_Atb_Plan
    # Same input coercion as fiNufft_Ax, so real-valued vectors from lsqr
    # reach the complex plan with the expected dtype.
    x = np.asarray(x, dtype=np.complex128)
    y = plan.execute(x, None)
    return y.flatten()
# Set up a random test image: values drawn from U[0, 1), with those above 0.5 set to 100.
StartingImage = np.random.rand(148, 148)
StartingImage[StartingImage > 0.5] = 100
# Centred 2-D FFT of the image, flattened: the right-hand side passed to lsqr.
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = k_space.ravel()

# Set up indices: integer coordinates centred on zero (-74 .. 73 for a size of 148).
# Derived from the image shape so the number of points always equals len(fk1).
x_lin_size, y_lin_size = StartingImage.shape
xn_lin = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
yn_lin = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
xn_lin, yn_lin = np.meshgrid(xn_lin, yn_lin, indexing='ij')
xn_lin = xn_lin.flatten()
yn_lin = yn_lin.flatten()

# A deliberately hacky way to get some distorted indices
# which somewhat resemble the real case.
xj = (xn_lin + np.sin(xn_lin) * 5) * 10
yj = (yn_lin + np.sin(yn_lin) * 5) * 10
sk = xn_lin / x_lin_size
tk = yn_lin / y_lin_size

# finufft plans, positional args: (type=3, dim=2, n_trans=1, eps=1e-6, isign).
# The adjoint plan swaps the source/target point sets and flips isign.
Nufft_Ax_Plan = Plan(3, 2, 1, 1e-06, -1)
Nufft_Ax_Plan.setpts(xj, yj, None, sk, tk)
Nufft_Atb_Plan = Plan(3, 2, 1, 1e-06, 1)
Nufft_Atb_Plan.setpts(sk, tk, None, xj, yj)

# Giving dtype explicitly stops LinearOperator from calling matvec once at
# construction just to infer it (that would run an extra, untimed NUFFT).
n_points = fk1.shape[0]
A = LinearOperator((n_points, n_points), matvec=fiNufft_Ax, rmatvec=fiNufft_Atb,
                   dtype=np.complex128)
StartingImage = StartingImage.flatten().astype(complex)
maxit = 20

# Time 10 lsqr runs of `maxit` iterations each, using x0 = the flattened starting image.
run_times = []
for _ in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=maxit, x0=StartingImage)
    run_times.append(perf_counter() - _start_time)
# Report mean ± standard deviation of the run times in seconds.
print(f'run time: {np.mean(run_times): 1.2f} \u00B1 {np.std(run_times): 1.2f}s')
|
Example that does not use finufft: """
same example but without any finufft dependency
"""
# from finufft import Plan
import numpy as np
from scipy.sparse.linalg import lsqr
from scipy.sparse.linalg import LinearOperator
from scipy.fft import fft2
from scipy.fft import fftshift
from time import perf_counter
from pynufft import NUFFT, helper
def PyNufft_Ax(x):
    """
    Forward NUFFT with pynufft: return A*x, for use as a LinearOperator matvec.

    This is the counterpart of ``fiNufft_Ax`` in the finufft example. It
    applies the forward transform once. It deliberately does not use
    ``Nufft_Ax_Plan.solve(...)``: that runs an iterative (CG) inverse
    reconstruction rather than applying A. With solve, the operator would not
    be the adjoint partner of ``PyNufft_Atb``, and every lsqr iteration would
    hide a full inner solve.

    Parameters
    ----------
    x : array_like
        Flattened image vector of length prod(Nd). It is reshaped to the
        module-level image shape ``Nd`` before the transform.

    Returns
    -------
    The forward transform of the image at the planned non-uniform points.
    """
    y = Nufft_Ax_Plan.forward(np.reshape(x, Nd))
    return y
def PyNufft_Atb(x):
    """
    Adjoint NUFFT with pynufft: return A'*x, for use as a LinearOperator rmatvec.

    Delegates directly to ``Nufft_Ax_Plan.adjoint``, the plan built by the
    script below.
    """
    return Nufft_Ax_Plan.adjoint(x)
# Set up a random test image: values drawn from U[0, 1), with those above 0.5 set to 100.
StartingImage = np.random.rand(148, 148)
# Take both dimensions from the image (columns were previously read from shape[0]).
Nrows, Ncolumns = StartingImage.shape
StartingImage[StartingImage > 0.5] = 100
# Centred 2-D FFT of the image, flattened: the right-hand side passed to lsqr.
k_space = fftshift(fft2(fftshift(StartingImage)))
fk1 = k_space.ravel()

# Set up indices: integer coordinates centred on zero (-74 .. 73 for a size of 148).
x_lin_size, y_lin_size = StartingImage.shape
xn_lin = np.linspace(-x_lin_size / 2, -x_lin_size / 2 + x_lin_size - 1, x_lin_size)
yn_lin = np.linspace(-y_lin_size / 2, -y_lin_size / 2 + y_lin_size - 1, y_lin_size)
xn_lin, yn_lin = np.meshgrid(xn_lin, yn_lin, indexing='ij')
xn_lin = xn_lin.flatten()
yn_lin = yn_lin.flatten()

# A deliberately hacky way to get some distorted indices
# which somewhat resemble the real case.
xj = (xn_lin + np.sin(xn_lin) * 5) * 10
yj = (yn_lin + np.sin(yn_lin) * 5) * 10
sk = xn_lin / x_lin_size
tk = yn_lin / y_lin_size
yn_dis = yj / (2 * np.pi)
xn_dis = xj / (2 * np.pi)

# Instantiate plan
# Nufft_Ax_Plan = NUFFT(helper.device_list()[0])
Nufft_Ax_Plan = NUFFT()
# Scale the distorted positions to the angular coordinates passed to pynufft.
# NOTE(review): with these synthetic points the values reach roughly +/-5.3,
# i.e. outside [-pi, pi]; confirm pynufft treats them as intended.
Kx_dis_pytorch = np.reshape(xn_dis, [Nrows, Ncolumns]) / Nrows * 2 * np.pi
Ky_dis_pytorch = np.reshape(yn_dis, [Nrows, Ncolumns]) / Ncolumns * 2 * np.pi
# One (kx, ky) row per sample, in row-major (i, j) order; same as the former double loop.
k_xy_dis = np.column_stack([Kx_dis_pytorch.ravel(), Ky_dis_pytorch.ravel()])

Nd = StartingImage.shape  # image size
Kd = k_space.shape  # k-space grid size
Jd = (3, 3)  # interpolation size
Nufft_Ax_Plan.plan(k_xy_dis, Nd, Kd, Jd)

# Giving dtype explicitly stops LinearOperator from calling matvec once at
# construction just to infer it.
n_points = fk1.shape[0]
A = LinearOperator((n_points, n_points), matvec=PyNufft_Ax, rmatvec=PyNufft_Atb,
                   dtype=np.complex128)

# Time 10 lsqr runs of 20 iterations each.
run_times = []
for _ in range(10):
    _start_time = perf_counter()
    x1 = lsqr(A, fk1, iter_lim=20, x0=None)
    run_times.append(perf_counter() - _start_time)
# Report mean ± standard deviation of the run times in seconds.
print(f'run time: {np.mean(run_times): 1.2f} \u00B1 {np.std(run_times): 1.2f}s')
Windows: 2.18 ± 0.18 s
def dummy_Ax(x):
    """
    Stand-in forward operator for timing experiments without any NUFFT.

    Sleeps 0.1 s to mimic the cost of a transform, then returns the
    element-wise square of ``x`` reshaped to 148 x 148.

    Parameters
    ----------
    x : numpy.ndarray
        Array with 148 * 148 elements.

    Returns
    -------
    numpy.ndarray
        A (148, 148) array of squared values.
    """
    # Imported locally: the file only imports perf_counter from time, so a
    # bare `sleep` would raise NameError.
    from time import sleep

    sleep(.1)
    y = x.reshape(148, 148) ** 2
    return y
def dummmy_Atb(x):
    """
    Stand-in adjoint operator for timing experiments without any NUFFT.

    Sleeps 0.1 s to mimic the cost of a transform, then returns the
    element-wise square root of ``x`` reshaped to 148 x 148.

    Parameters
    ----------
    x : numpy.ndarray
        Array with 148 * 148 elements.

    Returns
    -------
    numpy.ndarray
        A (148, 148) array of square roots.
    """
    # Imported locally: the file only imports perf_counter from time, so a
    # bare `sleep` would raise NameError.
    from time import sleep

    sleep(.1)
    y = np.sqrt(x.reshape(148, 148))
    return y
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This script allows anyone to reproduce the issue described here.
set up environment and sample data
Run example
Copy the code below into a new file at the MRI_DistortionQA root:
The text was updated successfully, but these errors were encountered: