Skip to content

Commit

Permalink
Use unified memory for scalar indexing of permutation matrices (#313)
Browse files Browse the repository at this point in the history
Co-authored-by: Tim Besard <[email protected]>
  • Loading branch information
tgymnich and maleadt authored Oct 2, 2024
1 parent c8cf84a commit ff7c7eb
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions lib/mps/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ end

# Metal's pivoting sequence needs to be iterated sequentially...
# TODO: figure out a GPU-compatible way to get the permutation matrix
LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
LinearAlgebra.ipiv2perm(v::MtlVector, maxi::Integer) =
LinearAlgebra.ipiv2perm(Array(v), maxi)
LinearAlgebra.ipiv2perm(v::MtlVector{<:Any,MTL.CPUStorage}, maxi::Integer) =
LinearAlgebra.ipiv2perm(unsafe_wrap(Array, v), maxi)

@autoreleasepool function LinearAlgebra.lu(A::MtlMatrix{T};
check::Bool=true) where {T<:MtlFloat}
Expand All @@ -129,7 +131,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =
end

P = similar(A, UInt32, 1, min(N, M))
status = MtlArray{MPSMatrixDecompositionStatus}(undef)
status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef)

commitAndContinue!(cmdbuf) do cbuf
mps_p = MPSMatrix(P)
Expand All @@ -150,7 +152,7 @@ LinearAlgebra.ipiv2perm(v::MtlVector{T}, maxi::Integer) where T =

wait_completed(cmdbuf)

status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[])
status = convert(LinearAlgebra.BlasInt, status[])
check && checknonsingular(status)

return LinearAlgebra.LU(B, p, status)
Expand Down Expand Up @@ -187,7 +189,7 @@ end
end

P = similar(A, UInt32, 1, min(N, M))
status = MtlArray{MPSMatrixDecompositionStatus}(undef)
status = MtlArray{MPSMatrixDecompositionStatus,0,SharedStorage}(undef)

commitAndContinue!(cmdbuf) do cbuf
mps_p = MPSMatrix(P)
Expand All @@ -205,7 +207,7 @@ end

wait_completed(cmdbuf)

status = convert(LinearAlgebra.BlasInt, Metal.@allowscalar status[])
status = convert(LinearAlgebra.BlasInt, status[])
check && _check_lu_success(status, allowsingular)

return LinearAlgebra.LU(A, p, status)
Expand Down

3 comments on commit ff7c7eb

@maleadt
Copy link
Member

@maleadt maleadt commented on ff7c7eb Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/116466

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.4.0 -m "<description of version>" ff7c7ebdcd5f6d513614e3a27062d06a07ce81f7
git push origin v1.4.0

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Metal Benchmarks

Benchmark suite Current: ff7c7eb Previous: 71b784e Ratio
private array/construct 26687.5 ns 23715.25 ns 1.13
private array/broadcast 465979.5 ns 474145.5 ns 0.98
private array/random/randn/Float32 993270.5 ns 994125 ns 1.00
private array/random/randn!/Float32 632166.5 ns 644458.5 ns 0.98
private array/random/rand!/Int64 568500 ns 569958 ns 1.00
private array/random/rand!/Float32 583500 ns 606250 ns 0.96
private array/random/rand/Int64 880458 ns 831750 ns 1.06
private array/random/rand/Float32 844333.5 ns 897625 ns 0.94
private array/copyto!/gpu_to_gpu 614333 ns 660666 ns 0.93
private array/copyto!/cpu_to_gpu 739479 ns 555208 ns 1.33
private array/copyto!/gpu_to_cpu 599208 ns 709417 ns 0.84
private array/accumulate/1d 1447750.5 ns 1430125 ns 1.01
private array/accumulate/2d 1496375 ns 1499500 ns 1.00
private array/iteration/findall/int 2263917 ns 2210520.5 ns 1.02
private array/iteration/findall/bool 1989875 ns 2041209 ns 0.97
private array/iteration/findfirst/int 1678000 ns 1704833 ns 0.98
private array/iteration/findfirst/bool 1663625 ns 1645334 ns 1.01
private array/iteration/scalar 2393834 ns 2430625 ns 0.98
private array/iteration/logical 3431520.5 ns 3432895.5 ns 1.00
private array/iteration/findmin/1d 1794125 ns 1763667 ns 1.02
private array/iteration/findmin/2d 1403416 ns 1353479 ns 1.04
private array/reductions/reduce/1d 805792 ns 730853.5 ns 1.10
private array/reductions/reduce/2d 704146 ns 709708 ns 0.99
private array/reductions/mapreduce/1d 815812.5 ns 800041 ns 1.02
private array/reductions/mapreduce/2d 716666.5 ns 713125 ns 1.00
private array/permutedims/4d 943959 ns 949333 ns 0.99
private array/permutedims/2d 938875 ns 930958 ns 1.01
private array/permutedims/3d 1005416.5 ns 1018708.5 ns 0.99
private array/copy 862875 ns 582583 ns 1.48
latency/precompile 4407793041 ns 4403995333 ns 1.00
latency/ttfp 6915521687.5 ns 6895957979 ns 1.00
latency/import 726643917 ns 723655188 ns 1.00
integration/metaldevrt 749270.5 ns 757604 ns 0.99
integration/byval/slices=1 1557959 ns 1623541 ns 0.96
integration/byval/slices=3 8832020.5 ns 8853854 ns 1.00
integration/byval/reference 1611291 ns 1573521 ns 1.02
integration/byval/slices=2 2583750 ns 2624459 ns 0.98
kernel/indexing 476584 ns 455583 ns 1.05
kernel/indexing_checked 441500 ns 461916 ns 0.96
kernel/launch 10875 ns 10875 ns 1
metal/synchronization/stream 19208 ns 19250 ns 1.00
metal/synchronization/context 19750 ns 19791 ns 1.00
shared array/construct 23756.916666666664 ns 23972.166666666668 ns 0.99
shared array/broadcast 469584 ns 478708 ns 0.98
shared array/random/randn/Float32 1020166 ns 987500 ns 1.03
shared array/random/randn!/Float32 634458 ns 641062.5 ns 0.99
shared array/random/rand!/Int64 572000 ns 576520.5 ns 0.99
shared array/random/rand!/Float32 593208.5 ns 592333.5 ns 1.00
shared array/random/rand/Int64 742792 ns 870458 ns 0.85
shared array/random/rand/Float32 898812.5 ns 935229 ns 0.96
shared array/copyto!/gpu_to_gpu 659667 ns 546667 ns 1.21
shared array/copyto!/cpu_to_gpu 94458 ns 94125 ns 1.00
shared array/copyto!/gpu_to_cpu 84333 ns 84208 ns 1.00
shared array/accumulate/1d 1418250 ns 1434979 ns 0.99
shared array/accumulate/2d 1500167 ns 1497729 ns 1.00
shared array/iteration/findall/int 1939666 ns 1971125 ns 0.98
shared array/iteration/findall/bool 1746333 ns 1777500 ns 0.98
shared array/iteration/findfirst/int 1413458 ns 1410291 ns 1.00
shared array/iteration/findfirst/bool 1374750 ns 1388708 ns 0.99
shared array/iteration/scalar 189167 ns 189562.5 ns 1.00
shared array/iteration/logical 3212770.5 ns 3205291 ns 1.00
shared array/iteration/findmin/1d 1481709 ns 1479229 ns 1.00
shared array/iteration/findmin/2d 1379250 ns 1373083.5 ns 1.00
shared array/reductions/reduce/1d 659583 ns 616666 ns 1.07
shared array/reductions/reduce/2d 706354 ns 716854.5 ns 0.99
shared array/reductions/mapreduce/1d 620667 ns 686417 ns 0.90
shared array/reductions/mapreduce/2d 704958.5 ns 710584 ns 0.99
shared array/permutedims/4d 963438 ns 960250 ns 1.00
shared array/permutedims/2d 939020.5 ns 925458.5 ns 1.01
shared array/permutedims/3d 1003520.5 ns 1015208.5 ns 0.99
shared array/copy 880541 ns 598354.5 ns 1.47

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.