Skip to content

Commit

Permalink
🚀 perf: copy & generate weights faster (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-francoisreboud authored May 12, 2024
1 parent 192f994 commit a9d176c
Show file tree
Hide file tree
Showing 22 changed files with 1,038 additions and 218 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.

## [unreleased]

🚀 **perf:** copy & generate weights faster ([#119](https://github.com/owkin/GrAIdient/pull/119))\
🚀 **perf:** Convolution2D ([#118](https://github.com/owkin/GrAIdient/pull/118))\
🪜 **feat:** LayerCAM2D -> VQGrad2D, LayerCAMSeq -> VQGradSeq ([#117](https://github.com/owkin/GrAIdient/pull/117))\
⚙️ **core:** GELU vs GELUApprox ([#113](https://github.com/owkin/GrAIdient/pull/113))\
Expand Down
209 changes: 209 additions & 0 deletions Sources/GrAIdient/Core/Layer/LayerUpdate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
//

import Foundation
import Accelerate

/// Error occurring in an output layer.
public enum LossError: Error
Expand Down Expand Up @@ -288,6 +289,40 @@ extension LayerWeightInit
return weightsList
}

///
/// Generate the initial weights values directly into a buffer,
/// dispatching on the layer's weight initialization class.
///
/// The number of values written is `weightListSize`; the fan-in/fan-out
/// used by the initializers comes from `connectivityIO`.
///
/// - Parameter buffer: The buffer receiving the generated weights
///   (assumed to hold at least `weightListSize` elements — TODO confirm
///   against callers).
///
public func generateWeightsList(
    buffer: UnsafeMutableBufferPointer<Float>)
{
    let nbElems = weightListSize
    switch weightInitClass {
    case .XavierUniform:
        Self.XavierUniform(
            nbElems: nbElems,
            connectivityIO: connectivityIO,
            buffer: buffer
        )
    case .XavierNormal:
        Self.XavierNormal(
            nbElems: nbElems,
            connectivityIO: connectivityIO,
            buffer: buffer
        )
    case .KaimingUniform:
        Self.KaimingUniform(
            nbElems: nbElems,
            coeff: coeffInitWeights,
            connectivityIO: connectivityIO,
            buffer: buffer
        )
    case .KaimingNormal:
        Self.KaimingNormal(
            nbElems: nbElems,
            coeff: coeffInitWeights,
            connectivityIO: connectivityIO,
            buffer: buffer
        )
    }
}

///
/// Xavier uniform initialization method.
///
Expand All @@ -309,6 +344,48 @@ extension LayerWeightInit
return values
}

///
/// Xavier uniform initialization method.
///
/// Fills `buffer` with `nbElems` values drawn uniformly from
/// `[-bound, bound]` where `bound = sqrt(6 / (fanIn + fanOut))`.
///
/// - Parameters:
///     - nbElems: Number of weights to initialize.
///     - connectivityIO: Number of input and output connections.
///     - buffer: The buffer of values.
///
static func XavierUniform(
    nbElems: Int,
    connectivityIO: (Int, Int),
    buffer: UnsafeMutableBufferPointer<Float>)
{
    // Glorot & Bengio (2010): bound = sqrt(6 / (fanIn + fanOut)).
    let bound = sqrt(6) / sqrt(Float(connectivityIO.0 + connectivityIO.1))
    if #available(macOS 13.0, *)
    {
        guard
            var arrayDescriptor = BNNSNDArrayDescriptor(
                data: buffer,
                shape: .vector(nbElems)),
            let randomNumberGenerator = BNNSCreateRandomGenerator(
                BNNSRandomGeneratorMethodAES_CTR,
                nil) else
        {
            // Distinguishable diagnostic instead of a bare fatalError().
            fatalError(
                "XavierUniform: could not create BNNS array descriptor " +
                "or random generator."
            )
        }
        
        BNNSRandomFillUniformFloat(
            randomNumberGenerator,
            &arrayDescriptor,
            -bound,
            bound
        )
        
        BNNSDestroyRandomGenerator(randomNumberGenerator)
    }
    else
    {
        fatalError(
            "XavierUniform: BNNS random generation requires macOS 13.0+."
        )
    }
}

///
/// Xavier normal initialization method.
///
Expand All @@ -330,11 +407,54 @@ extension LayerWeightInit
return values
}

///
/// Xavier normal initialization method.
///
/// Fills `buffer` with `nbElems` values drawn from a normal distribution
/// with mean 0 and standard deviation `sqrt(2 / (fanIn + fanOut))`.
///
/// - Parameters:
///     - nbElems: Number of weights to initialize.
///     - connectivityIO: Number of input and output connections.
///     - buffer: The buffer of values.
///
static func XavierNormal(
    nbElems: Int,
    connectivityIO: (Int, Int),
    buffer: UnsafeMutableBufferPointer<Float>)
{
    // Glorot & Bengio (2010): std = sqrt(2 / (fanIn + fanOut)).
    let std = sqrt(2) / sqrt(Float(connectivityIO.0 + connectivityIO.1))
    if #available(macOS 13.0, *)
    {
        guard
            var arrayDescriptor = BNNSNDArrayDescriptor(
                data: buffer,
                shape: .vector(nbElems)),
            let randomNumberGenerator = BNNSCreateRandomGenerator(
                BNNSRandomGeneratorMethodAES_CTR,
                nil) else
        {
            // Distinguishable diagnostic instead of a bare fatalError().
            fatalError(
                "XavierNormal: could not create BNNS array descriptor " +
                "or random generator."
            )
        }
        
        BNNSRandomFillNormalFloat(
            randomNumberGenerator,
            &arrayDescriptor,
            0.0,
            std
        )
        
        BNNSDestroyRandomGenerator(randomNumberGenerator)
    }
    else
    {
        fatalError(
            "XavierNormal: BNNS random generation requires macOS 13.0+."
        )
    }
}

///
/// Kaiming uniform initialization method.
///
/// - Parameters:
/// - nbElems: Number of weights to initialize.
/// - coeff: Multiplicative coefficient.
/// - connectivityIO: Number of input and output connections.
/// - Returns: Weights values.
///
Expand All @@ -352,11 +472,56 @@ extension LayerWeightInit
return values
}

///
/// Kaiming uniform initialization method.
///
/// Fills `buffer` with `nbElems` values drawn uniformly from
/// `[-bound, bound]` where `bound = sqrt(3) * coeff / sqrt(fanIn)`.
///
/// - Parameters:
///     - nbElems: Number of weights to initialize.
///     - coeff: Multiplicative coefficient.
///     - connectivityIO: Number of input and output connections.
///     - buffer: The buffer of values.
///
static func KaimingUniform(
    nbElems: Int,
    coeff: Float,
    connectivityIO: (Int, Int),
    buffer: UnsafeMutableBufferPointer<Float>)
{
    // He et al. (2015): bound = sqrt(3) * gain / sqrt(fanIn).
    let bound = sqrt(3) * coeff / sqrt(Float(connectivityIO.0))
    if #available(macOS 13.0, *)
    {
        guard
            var arrayDescriptor = BNNSNDArrayDescriptor(
                data: buffer,
                shape: .vector(nbElems)),
            let randomNumberGenerator = BNNSCreateRandomGenerator(
                BNNSRandomGeneratorMethodAES_CTR,
                nil) else
        {
            // Distinguishable diagnostic instead of a bare fatalError().
            fatalError(
                "KaimingUniform: could not create BNNS array descriptor " +
                "or random generator."
            )
        }
        
        BNNSRandomFillUniformFloat(
            randomNumberGenerator,
            &arrayDescriptor,
            -bound,
            bound
        )
        
        BNNSDestroyRandomGenerator(randomNumberGenerator)
    }
    else
    {
        fatalError(
            "KaimingUniform: BNNS random generation requires macOS 13.0+."
        )
    }
}

///
/// Kaiming normal initialization method.
///
/// - Parameters:
/// - nbElems: Number of weights to initialize.
/// - coeff: Multiplicative coefficient.
/// - connectivityIO: Number of input and output connections.
/// - Returns: Weights values.
///
Expand All @@ -373,6 +538,50 @@ extension LayerWeightInit
}
return values
}

///
/// Kaiming normal initialization method.
///
/// Fills `buffer` with `nbElems` values drawn from a normal distribution
/// with mean 0 and standard deviation `coeff / sqrt(fanIn)`.
///
/// - Parameters:
///     - nbElems: Number of weights to initialize.
///     - coeff: Multiplicative coefficient.
///     - connectivityIO: Number of input and output connections.
///     - buffer: The buffer of values.
///
static func KaimingNormal(
    nbElems: Int,
    coeff: Float,
    connectivityIO: (Int, Int),
    buffer: UnsafeMutableBufferPointer<Float>)
{
    // He et al. (2015): std = gain / sqrt(fanIn).
    let std = coeff / sqrt(Float(connectivityIO.0))
    if #available(macOS 13.0, *)
    {
        guard
            var arrayDescriptor = BNNSNDArrayDescriptor(
                data: buffer,
                shape: .vector(nbElems)),
            let randomNumberGenerator = BNNSCreateRandomGenerator(
                BNNSRandomGeneratorMethodAES_CTR,
                nil) else
        {
            // Distinguishable diagnostic instead of a bare fatalError().
            fatalError(
                "KaimingNormal: could not create BNNS array descriptor " +
                "or random generator."
            )
        }
        
        BNNSRandomFillNormalFloat(
            randomNumberGenerator,
            &arrayDescriptor,
            0.0,
            std
        )
        
        BNNSDestroyRandomGenerator(randomNumberGenerator)
    }
    else
    {
        fatalError(
            "KaimingNormal: BNNS random generation requires macOS 13.0+."
        )
    }
}
}

///
Expand Down
21 changes: 8 additions & 13 deletions Sources/GrAIdient/Layer1D/Constant1D.swift
Original file line number Diff line number Diff line change
Expand Up @@ -259,21 +259,16 @@ public class Constant1D: Layer1D, LayerUpdate
)

let weightsPtr = _wBuffers.w_p!.shared.buffer
if _weightsList.count == 0
{
for depth in 0..<nbNeurons
{
weightsPtr[depth] = 0.0
}
}
else
if _weightsList.count != 0
{
for depth in 0..<nbNeurons
{
weightsPtr[depth] = _weightsList[depth]
}
_weightsList = []
copyFloatArrayToBuffer(
array: &_weightsList,
buffer: weightsPtr,
start: 0,
nbElems: nbNeurons
)
}
_weightsList = []

MetalKernel.get.upload([_wBuffers.w_p!])
_wDeltaWeights = nil
Expand Down
35 changes: 15 additions & 20 deletions Sources/GrAIdient/Layer1D/FullyConnected.swift
Original file line number Diff line number Diff line change
Expand Up @@ -567,12 +567,6 @@ public class FullyConnected: Activation1D, LayerWithActivation, LayerWeightInit
///
public func initWeightsGPU()
{
if _weightsList.count == 0
{
_weightsList = generateWeightsList()
_weightsList += [Float](repeating: 0.0, count: weightHeight)
}

_wBuffers = WeightBuffers(
nbElems: weightHeight * weightWidth,
deviceID: deviceID
Expand All @@ -585,25 +579,26 @@ public class FullyConnected: Activation1D, LayerWithActivation, LayerWeightInit
let weightsPtr = _wBuffers.w_p!.shared.buffer
let biasesPtr = _bBuffers.w_p!.shared.buffer

for elem in 0..<weightHeight * weightWidth
{
weightsPtr[elem] = _weightsList[elem]
}

// In both cases, biases may have been set by caller or by ourselves.
if _updateBiases
if _weightsList.count == 0
{
let offset = weightHeight * weightWidth
for depth in 0..<weightHeight
{
biasesPtr[depth] = _weightsList[offset + depth]
}
generateWeightsList(buffer: weightsPtr)
}
else
{
for depth in 0..<weightHeight
copyFloatArrayToBuffer(
array: &_weightsList,
buffer: weightsPtr,
start: 0,
nbElems: weightHeight * weightWidth
)
if _updateBiases
{
biasesPtr[depth] = 0.0
copyFloatArrayToBuffer(
array: &_weightsList,
buffer: biasesPtr,
start: weightHeight * weightWidth,
nbElems: weightHeight
)
}
}
_weightsList = []
Expand Down
21 changes: 8 additions & 13 deletions Sources/GrAIdient/Layer2D/Constant2D.swift
Original file line number Diff line number Diff line change
Expand Up @@ -316,21 +316,16 @@ public class Constant2D: Layer2D, LayerResize, LayerUpdate
)

let weightsPtr = _wBuffers.w_p!.shared.buffer
if _weightsList.count == 0
{
for depth in 0..<nbChannels
{
weightsPtr[depth] = 0.0
}
}
else
if _weightsList.count != 0
{
for depth in 0..<nbChannels
{
weightsPtr[depth] = _weightsList[depth]
}
_weightsList = []
copyFloatArrayToBuffer(
array: &_weightsList,
buffer: weightsPtr,
start: 0,
nbElems: nbChannels
)
}
_weightsList = []

MetalKernel.get.upload([_wBuffers.w_p!])
_wDeltaWeights = nil
Expand Down
Loading

0 comments on commit a9d176c

Please sign in to comment.