Skip to content

Commit

Permalink
adapted ALS (EVP) to complex cases
Browse files Browse the repository at this point in the history
  • Loading branch information
PGelss authored Jul 24, 2024
1 parent d180a15 commit 087fd35
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions scikit_tt/solvers/evp.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,19 +278,18 @@ def __construct_left_stacks(i: int, trains, stacks):
else:

# contract previous stack element with solution and operator cores
stacks.op_left[i] = np.tensordot(stacks.op_left[i - 1], trains.solution.cores[i - 1][:, :, 0, :], axes=(0, 0))
stacks.op_left[i] = np.tensordot(stacks.op_left[i - 1], np.conjugate(trains.solution.cores[i - 1][:, :, 0, :]), axes=(0, 0))
stacks.op_left[i] = np.tensordot(stacks.op_left[i], trains.operator.cores[i - 1], axes=([0, 2], [0, 2]))
stacks.op_left[i] = np.tensordot(stacks.op_left[i], trains.solution.cores[i - 1][:, :, 0, :], axes=([0, 2], [0, 1]))

if trains.operator_gevp is not None:
stacks.op_gevp_left[i] = np.tensordot(stacks.op_gevp_left[i - 1], trains.solution.cores[i - 1][:, :, 0, :], axes=(0, 0))
stacks.op_gevp_left[i] = np.tensordot(stacks.op_gevp_left[i - 1], np.conjugate(trains.solution.cores[i - 1][:, :, 0, :]), axes=(0, 0))
stacks.op_gevp_left[i] = np.tensordot(stacks.op_gevp_left[i], trains.operator_gevp.cores[i - 1], axes=([0, 2], [0, 2]))
stacks.op_gevp_left[i] = np.tensordot(stacks.op_gevp_left[i], trains.solution.cores[i - 1][:, :, 0, :], axes=([0, 2], [0, 1]))

for j in range(len(trains.previous)):
stacks.previous_left[j][i] = np.tensordot(stacks.previous_left[j][i - 1], trains.previous[j].cores[i - 1][:, :, 0, :], axes=(0, 0))
stacks.previous_left[j][i] = np.tensordot(stacks.previous_left[j][i], trains.solution.cores[i - 1][:, :, 0, :], axes=([0, 1], [0, 1]))

stacks.previous_left[j][i] = np.tensordot(stacks.previous_left[j][i], np.conjugate(trains.solution.cores[i - 1][:, :, 0, :]), axes=([0, 1], [0, 1]))


def __construct_right_stacks(i: int, trains, stacks):
Expand Down Expand Up @@ -321,17 +320,17 @@ def __construct_right_stacks(i: int, trains, stacks):
else:

# contract previous stack element with solution and operator cores
stacks.op_right[i] = np.tensordot(trains.solution.cores[i + 1][:, :, 0, :], stacks.op_right[i + 1], axes=(2, 2))
stacks.op_right[i] = np.tensordot(np.conjugate(trains.solution.cores[i + 1][:, :, 0, :]), stacks.op_right[i + 1], axes=(2, 2))
stacks.op_right[i] = np.tensordot(trains.operator.cores[i + 1], stacks.op_right[i], axes=([1, 3], [1, 3]))
stacks.op_right[i] = np.tensordot(trains.solution.cores[i + 1][:, :, 0, :], stacks.op_right[i], axes=([1, 2], [1, 3]))

if trains.operator_gevp is not None:
stacks.op_gevp_right[i] = np.tensordot(trains.solution.cores[i + 1][:, :, 0, :], stacks.op_gevp_right[i + 1], axes=(2, 2))
stacks.op_gevp_right[i] = np.tensordot(np.conjugate(trains.solution.cores[i + 1][:, :, 0, :]), stacks.op_gevp_right[i + 1], axes=(2, 2))
stacks.op_gevp_right[i] = np.tensordot(trains.operator_gevp.cores[i + 1], stacks.op_gevp_right[i], axes=([1, 3], [1, 3]))
stacks.op_gevp_right[i] = np.tensordot(trains.solution.cores[i + 1][:, :, 0, :], stacks.op_gevp_right[i], axes=([1, 2], [1, 3]))

for j in range(len(trains.previous)):
stacks.previous_right[j][i] = np.tensordot(trains.solution.cores[i + 1][:, :, 0, :], stacks.previous_right[j][i + 1], axes=(2, 1))
stacks.previous_right[j][i] = np.tensordot(np.conjugate(trains.solution.cores[i + 1][:, :, 0, :]), stacks.previous_right[j][i + 1], axes=(2, 1))
stacks.previous_right[j][i] = np.tensordot(trains.previous[j].cores[i + 1][:, :, 0, :], stacks.previous_right[j][i], axes=([1, 2], [1, 2]))


Expand Down Expand Up @@ -379,7 +378,7 @@ def __construct_micro_matrices(i: int, trains, stacks, shift: float) -> np.ndarr
tmp = np.tensordot(tmp, stacks.previous_right[j][i], axes=(2, 0))
tmp = tmp.reshape(trains.solution.ranks[i] * trains.previous[j].row_dims[i] * trains.solution.ranks[i + 1], 1)

micro_op += shift*tmp.dot(tmp.T)
micro_op += shift*tmp.dot(np.conjugate(tmp.T))

return micro_op, micro_op_gevp

Expand Down Expand Up @@ -435,13 +434,13 @@ def __update_core(i: int, micro_op: np.ndarray,
if solver == 'eigh':
eigenvalues, eigenvectors = lin.eigh(micro_op, b=micro_op_gevp, overwrite_a=True, overwrite_b=True,
check_finite=False,
subset_by_index=(micro_op.shape[0] - number_ev, micro_op.shape[0] - 1))
eigvals=(micro_op.shape[0] - number_ev, micro_op.shape[0] - 1))
eigenvalues = eigenvalues[::-1]
eigenvectors = eigenvectors[:, ::-1]

if real is True:
eigenvalues = np.real(eigenvalues)
eigenvectors = np.real(eigenvectors)
#eigenvectors = np.real(eigenvectors)

# reshape solution and orthonormalization
# ---------------------------------------
Expand Down

0 comments on commit 087fd35

Please sign in to comment.