Skip to content

Commit

Permalink
LSQML: first replicate ML
Browse files Browse the repository at this point in the history
  • Loading branch information
jfowkes committed Sep 11, 2024
1 parent 10f8c49 commit ab70cdc
Showing 1 changed file with 70 additions and 66 deletions.
136 changes: 70 additions & 66 deletions ptypy/custom/LSQML.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def __init__(self, ptycho_parent, pars=None):
self.pr_tmin = None

# Other
self.tmin = None # for ML step only
self.ML_model = None

self.ptycho.citations.add_article(
Expand Down Expand Up @@ -118,6 +119,9 @@ def engine_initialize(self):
self.pr_buf = self.pr.copy(self.pr.ID + '_buf', fill=0.)
self.pr_tmin = {} # need scalar per named pod

# for ML step only
self.tmin = 1.

self._initialize_model()

def _initialize_model(self):
Expand Down Expand Up @@ -176,11 +180,17 @@ def engine_iterate(self, num=1):
self.ob_grad << new_ob_grad
self.pr_grad << new_pr_grad

# Next conjugate
self.ob_h *= bt
self.ob_h -= self.ob_grad
self.pr_h *= bt
# # Next conjugate
# self.ob_h *= bt
# self.ob_h -= self.ob_grad
# self.pr_h *= bt
# self.ob_h -= self.ob_grad

This comment has been minimized.

Copy link
@pierrethibault

pierrethibault Sep 11, 2024

Member

ob -> pr?

This comment has been minimized.

Copy link
@jfowkes

jfowkes Sep 11, 2024

Author Collaborator

Ahh... yes that's probably where it's going wrong!


# Next ML conjugate
self.ob_h *= bt / self.tmin
self.ob_h -= self.ob_grad
self.pr_h *= bt / self.tmin
self.pr_h -= self.pr_grad

# ############################
# # Compute steepest descent #
Expand All @@ -190,20 +200,20 @@ def engine_iterate(self, num=1):
# self.pr_h << new_pr_grad
# self.pr_h *= -1

##########################
# Average direction (25) #
##########################
self.ob_nrm += 1e-6
self.ob_h /= self.ob_nrm
self.pr_nrm += 1e-6
self.pr_h /= self.pr_nrm
# ##########################
# # Average direction (25) #
# ##########################
# self.ob_nrm += 1e-6
# self.ob_h /= self.ob_nrm
# self.pr_nrm += 1e-6
# self.pr_h /= self.pr_nrm

########################
# Compute step lengths #
########################
t2 = time.time()
self.ML_model.compute_step_lengths()
ts += time.time() - t2
# ########################
# # Compute step lengths #
# ########################
# t2 = time.time()
# self.ML_model.compute_step_lengths()
# ts += time.time() - t2

# ################################
# # Take weighted mean step (27) #
Expand Down Expand Up @@ -235,9 +245,9 @@ def engine_iterate(self, num=1):
################
t3 = time.time()
B = self.ML_model.poly_line_coeffs()
tmin = self.ptycho.FType(-.5 * B[1] / B[2])
self.ob_h *= tmin
self.pr_h *= tmin
self.tmin = self.ptycho.FType(-.5 * B[1] / B[2])
self.ob_h *= self.tmin
self.pr_h *= self.tmin
self.ob += self.ob_h
self.pr += self.pr_h
tu += time.time() - t3
Expand Down Expand Up @@ -347,9 +357,45 @@ def new_grad(self):

def compute_step_lengths(self):
"""
Compute optimization step lengths according to the noise model.
Compute LSQML optimization step lengths according.
"""
raise NotImplementedError

# Outer loop: through diffraction patterns
for dname, diff_view in self.di.views.items():
if not diff_view.active:
continue

# Third pod loop: calculate real-space step lengths
for name, pod in diff_view.pods.items():
if not pod.active:
continue

# Get xi
xi = pod.exit

# Get update directions
ob_h = self.ob_h[pod.ob_view]
pr_h = self.pr_h[pod.pr_view]

# Compute cross-terms
ob_h_pr = ob_h * pod.probe
pr_h_ob = pr_h * pod.object

# Calculate real-space step lengths (22)
M = np.zeros((2,2), dtype=np.cdouble)
rhs = np.zeros(2, dtype=np.double)
M[0,0] = np.sum(u.abs2(ob_h_pr)) + 1e-6
M[1,1] = np.sum(u.abs2(pr_h_ob)) + 1e-6
M[0,1] = np.sum(ob_h_pr * pr_h_ob.conj())
M[1,0] = np.sum(ob_h_pr.conj() * pr_h_ob)
rhs[0] = np.sum(np.real(xi * ob_h_pr.conj()))
rhs[1] = np.sum(np.real(xi * pr_h_ob.conj()))
#self.ob_tmin[name], self.pr_tmin[name] = np.linalg.solve(M, rhs)
self.ob_tmin[name], self.pr_tmin[name] = sp.linalg.solve(M, rhs, assume_a='her')

# # Calculate approx real-space step lengths (23)
# self.ob_tmin[name] = np.sum(np.real(xi * ob_h_pr.conj())) / (np.sum(u.abs2(ob_h_pr)) + 1e-6)
# self.pr_tmin[name] = np.sum(np.real(xi * pr_h_ob.conj())) / (np.sum(u.abs2(pr_h_ob)) + 1e-6)

def new_step(self):
"""
Expand Down Expand Up @@ -421,8 +467,8 @@ def __init__(self, MLengine):
for name, di_view in self.di.views.items():
if not di_view.active:
continue
self.weights[di_view] = (self.Irenorm * di_view.pod.ma_view.data
/ (1./self.Irenorm + di_view.data))
self.weights[di_view] = (di_view.pod.ma_view.data
/ (1 + di_view.data))

def __del__(self):
"""
Expand Down Expand Up @@ -708,48 +754,6 @@ def new_grad(self):

return error_dct

def compute_step_lengths(self):
"""
Compute optimization step lengths according to a Euclidean noise model.
"""

# Outer loop: through diffraction patterns
for dname, diff_view in self.di.views.items():
if not diff_view.active:
continue

# Third pod loop: calculate real-space step lengths
for name, pod in diff_view.pods.items():
if not pod.active:
continue

# Get xi
xi = pod.exit

# Get update directions
ob_h = self.ob_h[pod.ob_view]
pr_h = self.pr_h[pod.pr_view]

# Compute cross-terms
ob_h_pr = ob_h * pod.probe
pr_h_ob = pr_h * pod.object

# Calculate real-space step lengths (22)
M = np.zeros((2,2), dtype=np.cdouble)
rhs = np.zeros(2, dtype=np.double)
M[0,0] = np.sum(u.abs2(ob_h_pr)) + 1e-6
M[1,1] = np.sum(u.abs2(pr_h_ob)) + 1e-6
M[0,1] = np.sum(ob_h_pr * pr_h_ob.conj())
M[1,0] = np.sum(ob_h_pr.conj() * pr_h_ob)
rhs[0] = np.sum(np.real(xi * ob_h_pr.conj()))
rhs[1] = np.sum(np.real(xi * pr_h_ob.conj()))
#self.ob_tmin[name], self.pr_tmin[name] = np.linalg.solve(M, rhs)
self.ob_tmin[name], self.pr_tmin[name] = sp.linalg.solve(M, rhs, assume_a='her')

# # Calculate approx real-space step lengths (23)
# self.ob_tmin[name] = np.sum(np.real(xi * ob_h_pr.conj())) / (np.sum(u.abs2(ob_h_pr)) + 1e-6)
# self.pr_tmin[name] = np.sum(np.real(xi * pr_h_ob.conj())) / (np.sum(u.abs2(pr_h_ob)) + 1e-6)

def poly_line_coeffs(self):
"""
Compute the coefficients of the polynomial for line minimization
Expand Down

0 comments on commit ab70cdc

Please sign in to comment.