Skip to content

Commit

Permalink
✨ feat(layer_seq): LLM sliding window (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-francoisreboud authored Jul 19, 2024
1 parent 723b021 commit 54b4a30
Show file tree
Hide file tree
Showing 8 changed files with 764 additions and 76 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ All notable changes to this project will be documented in this file.

## [unreleased]

🚀 **examples**: 3 LLMs examples ([#130](https://github.com/owkin/GrAIdient/pull/130))\
**layer_seq:** LLM sliding window ([#131](https://github.com/owkin/GrAIdient/pull/131))\
🚀 **examples:** 3 LLMs examples ([#130](https://github.com/owkin/GrAIdient/pull/130))\
📚 **docs:** LLM doc & split tests ([#129](https://github.com/owkin/GrAIdient/pull/129))\
**layer_seq:** LLM generate ([#128](https://github.com/owkin/GrAIdient/pull/128))\
**layer_seq:** MultiplySeq, SiLU & LLM test ([127](https://github.com/owkin/GrAIdient/pull/127))\
Expand Down
150 changes: 116 additions & 34 deletions Sources/GrAIdient/LayerSeq/QuerySeq.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1270,20 +1270,25 @@ public class QueryCausalSeq: LayerMergeSeq
if cacheKey != nil && cacheSeq != nil &&
cacheKey.nbElems != batchSize * cacheSeqMax * nbNeuronsPrevKey
{
_cacheKeyTmp = FloatBuffer(
let cacheKeyTmp = FloatBuffer(
nbElems: batchSize * cacheSeqMax * nbNeuronsPrevKey,
deviceID: deviceID
)

let nbElems = batchSize * cacheSeq * nbNeuronsPrevKey
_copyGPU(nbElems: nbElems, from: cacheKey, to: _cacheKeyTmp)
_copyGPU(nbElems: nbElems, from: cacheKey, to: cacheKeyTmp)

cacheKey = FloatBuffer(
nbElems: batchSize * cacheSeqMax * nbNeuronsPrevKey,
deviceID: deviceID
)

_copyGPU(nbElems: nbElems, from: _cacheKeyTmp, to: cacheKey)
_copyGPU(nbElems: nbElems, from: cacheKeyTmp, to: cacheKey)

if batchSize > 1
{
_cacheKeyTmp = cacheKeyTmp
}
}
}

Expand Down Expand Up @@ -1664,29 +1669,29 @@ public class QueryCausalSeq: LayerMergeSeq
throw LayerError.Init(message: "`sequence` should be 1.")
}

_concatGPU()
_mergeCacheGPU()

let query = layersPrev[0] as! LayerSeq
let key = layersPrev[1] as! LayerSeq
let nbNeuronsPrevQuery = query.nbNeurons
let nbNeuronsPrevKey = key.nbNeurons
let nbNeurons = (cacheSeq + 1) * _nbHeadsQuery
let nbNeurons = min(cacheSeq + 1, cacheSeqMax) * _nbHeadsQuery

let pNbHeadsQuery: [UInt32] = [UInt32(_nbHeadsQuery)]
let pNbHeadsKey: [UInt32] = [UInt32(_nbHeadsKey)]
let pNbNeurons: [UInt32] = [UInt32(nbNeurons)]
let pNbNeuronsPrevQuery: [UInt32] = [UInt32(nbNeuronsPrevQuery)]
let pNbNeuronsPrevKey: [UInt32] = [UInt32(nbNeuronsPrevKey)]
let pNbBatch: [UInt32] = [UInt32(batchSize)]
let pSequence: [UInt32] = [UInt32(cacheSeq + 1)]
let pSequence: [UInt32] = [UInt32(min(cacheSeq + 1, cacheSeqMax))]

let kernel = (nbNeuronsPrevQuery / _nbHeadsQuery) % 4 == 0 ?
"queryCausalSeq4Generate" : "queryCausalSeqGenerate"
let command = MetalKernel.get.createCommand(
kernel, deviceID: deviceID
)
command.setBuffer(query.outs.metal, atIndex: 0)
command.setBuffer(_cacheKeyTmp.metal, atIndex: 1)
command.setBuffer(_getKeyCacheOutputGPU()!.metal, atIndex: 1)
command.setBytes(pNbHeadsQuery, atIndex: 2)
command.setBytes(pNbHeadsKey, atIndex: 3)
command.setBytes(pNbNeurons, atIndex: 4)
Expand All @@ -1702,55 +1707,71 @@ public class QueryCausalSeq: LayerMergeSeq
)
command.enqueue()

let nbElems = batchSize * (cacheSeq + 1) * nbNeuronsPrevKey
_copyGPU(nbElems: nbElems, from: _cacheKeyTmp, to: cacheKey)

cacheSeq += 1
}

/// Concatenate cache to key.
private func _concatGPU()
/// Merge cache to key.
private func _mergeCacheGPU()
{
let slidingWindow: Bool
if cacheSeq >= cacheSeqMax
{
slidingWindow = true
}
else
{
slidingWindow = false
}

let key = layersPrev[1] as! LayerSeq
let nbNeuronsPrevKey = key.nbNeurons
let nbNeurons = nbNeuronsPrevKey

let pNbNeurons: [UInt32] = [UInt32(nbNeurons)]
let pNbBatch: [UInt32] = [UInt32(batchSize)]
let pSequence: [UInt32] = [UInt32(cacheSeq + 1)]
let pSequence: [UInt32] = [UInt32(min(cacheSeq + 1, cacheSeqMax))]
let pSequenceCache: [UInt32] = [UInt32(cacheSeq)]
let pSequenceKey: [UInt32] = [UInt32(1)]

let metalKernel = MetalKernel.get
var command: MetalCommand

var globalOffset = 0

var pGlobalOffset: [UInt32] = [UInt32(globalOffset)]

let kernel = nbNeurons % 4 == 0 ?
"concat1Seq4Forward" : "concat1SeqForward"
let coeff = nbNeurons % 4 == 0 ? 4 : 1
command = metalKernel.createCommand(
kernel, deviceID: deviceID
)
command.setBuffer(cacheKey.metal, atIndex: 0)
command.setBytes(pGlobalOffset, atIndex: 1)
command.setBytes(pNbNeurons, atIndex: 2)
command.setBytes(pNbBatch, atIndex: 3)
command.setBytes(pSequence, atIndex: 4)
command.setBytes(pSequenceCache, atIndex: 5)
command.setBuffer(_cacheKeyTmp.metal, atIndex: 6)

command.dispatchThreads(
width: nbNeurons / coeff,
height: batchSize * cacheSeq
)
command.enqueue()
if batchSize != 1 && !slidingWindow
{
let pGlobalOffset: [UInt32] = [UInt32(globalOffset)]

command = metalKernel.createCommand(
kernel, deviceID: deviceID
)
command.setBuffer(_getKeyCacheInputGPU()!.metal, atIndex: 0)
command.setBytes(pGlobalOffset, atIndex: 1)
command.setBytes(pNbNeurons, atIndex: 2)
command.setBytes(pNbBatch, atIndex: 3)
command.setBytes(pSequence, atIndex: 4)
command.setBytes(pSequenceCache, atIndex: 5)
command.setBuffer(_getKeyCacheOutputGPU()!.metal, atIndex: 6)

command.dispatchThreads(
width: nbNeurons / coeff,
height: batchSize * cacheSeq
)
command.enqueue()
}

globalOffset += cacheSeq
globalOffset += cacheSeq % cacheSeqMax
// TODO: when using sliding window with an instruct model,
// it is risky to erase the header information!
// if cacheSeq >= cacheSeqMax
// {
// globalOffset += 5
// }

pGlobalOffset = [UInt32(globalOffset)]
let pGlobalOffset = [UInt32(globalOffset)]

command = metalKernel.createCommand(
kernel, deviceID: deviceID
Expand All @@ -1761,7 +1782,7 @@ public class QueryCausalSeq: LayerMergeSeq
command.setBytes(pNbBatch, atIndex: 3)
command.setBytes(pSequence, atIndex: 4)
command.setBytes(pSequenceKey, atIndex: 5)
command.setBuffer(_cacheKeyTmp.metal, atIndex: 6)
command.setBuffer(_getKeyCacheOutputGPU()!.metal, atIndex: 6)

command.dispatchThreads(
width: nbNeurons / coeff,
Expand All @@ -1770,6 +1791,67 @@ public class QueryCausalSeq: LayerMergeSeq
command.enqueue()
}

///
/// Select the key cache buffer to read from in the Metal kernel.
///
/// `cacheKey` and `_cacheKeyTmp` are alternated from one generation step to
/// the next: the parity of `cacheSeq` determines which buffer holds the
/// state to read at this step.
///
/// - Returns: key cache to use as input, or `nil` when generation
///   has not started (`cacheSeq` is still `nil`).
///
private func _getKeyCacheInputGPU() -> FloatBuffer?
{
    guard cacheSeq != nil else
    {
        return nil
    }
    return cacheSeq % 2 == 0 ? _cacheKeyTmp : cacheKey
}

///
/// Get key cache buffer to use as output in Metal kernel.
///
/// `cacheKey` and `_cacheKeyTmp` are alternated from one generation step to
/// the next, selected by the parity of `cacheSeq`.
///
/// - Returns: key cache to use as output, or `nil` when generation
///   has not started (`cacheSeq` is still `nil`).
///
private func _getKeyCacheOutputGPU() -> FloatBuffer?
{
    if cacheSeq != nil
    {
        // With a single batch element there is no need to alternate:
        // writes always target `cacheKey` directly.
        if batchSize == 1
        {
            return cacheKey
        }
        else
        {
            if cacheSeq >= cacheSeqMax // sliding window
            {
                // The cache key has not changed.
                // Once the window is full, `cacheSeq` no longer selects the
                // buffer; keep the parity of the last pre-window step
                // (`cacheSeqMax - 1`) so the same buffer stays current.
                if (cacheSeqMax - 1) % 2 == 0
                {
                    return cacheKey
                }
                else
                {
                    return _cacheKeyTmp
                }
            }
            // Mirror of `_getKeyCacheInputGPU`: the output buffer at a given
            // parity is the input buffer of the opposite parity.
            else if cacheSeq % 2 == 0
            {
                return cacheKey
            }
            else
            {
                return _cacheKeyTmp
            }
        }
    }
    return nil
}

/// Apply the forward pass in the GPU execution context.
private func _forwardGPU()
{
Expand Down
10 changes: 9 additions & 1 deletion Sources/GrAIdient/LayerSeq/SoftmaxSeq.swift
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,15 @@ public class SoftmaxSeq: LayerSeq
///
public class SoftmaxCausalSeq: SoftmaxSeq
{
/// Maximal sequence of cache.
public var cacheSeqMax = 128

/// Current cache sequence.
public var cacheSeq: Int! = nil

private enum Keys: String, CodingKey
{
case cacheSeqMax
case cacheSeq
}

Expand Down Expand Up @@ -401,6 +405,7 @@ public class SoftmaxCausalSeq: SoftmaxSeq
public required init(from decoder: Decoder) throws
{
let values = try decoder.container(keyedBy: Keys.self)
cacheSeqMax = try values.decode(Int.self, forKey: Keys.cacheSeqMax)
cacheSeq = try values.decodeIfPresent(Int.self, forKey: .cacheSeq)
try super.init(from: decoder)
}
Expand All @@ -419,6 +424,7 @@ public class SoftmaxCausalSeq: SoftmaxSeq
public override func encode(to encoder: Encoder) throws
{
var container = encoder.container(keyedBy: Keys.self)
try container.encode(cacheSeqMax, forKey: Keys.cacheSeqMax)
if cacheSeq != nil
{
try container.encode(cacheSeq, forKey: Keys.cacheSeq)
Expand Down Expand Up @@ -453,6 +459,8 @@ public class SoftmaxCausalSeq: SoftmaxSeq
nbHeads: _nbHeads,
params: params
)

layer.cacheSeqMax = cacheSeqMax
layer.cacheSeq = cacheSeq

return layer
Expand Down Expand Up @@ -507,7 +515,7 @@ public class SoftmaxCausalSeq: SoftmaxSeq

if let layerPrev = self.layerPrev as? LayerSeq
{
let nbNeurons = (cacheSeq + 1) * _nbHeads
let nbNeurons = min(cacheSeq + 1, cacheSeqMax) * _nbHeads

let pNbHeads: [UInt32] = [UInt32(_nbHeads)]
let pNbNeurons: [UInt32] = [UInt32(nbNeurons)]
Expand Down
Loading

0 comments on commit 54b4a30

Please sign in to comment.