Skip to content

Commit

Permalink
✨ feat(core): GELU vs GELUApprox (#113)
Browse files Browse the repository at this point in the history
  • Loading branch information
jean-francoisreboud authored Jan 5, 2024
1 parent 4969db6 commit 096b95d
Show file tree
Hide file tree
Showing 11 changed files with 383 additions and 64 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.

## [unreleased]

⚙️ **core:** GELU vs GELUApprox ([113](https://github.com/owkin/GrAIdient/pull/113))\
🚀 **perf:** QuerySelf & ValueSelf ([112](https://github.com/owkin/GrAIdient/pull/112))\
🚀 **perf:** benchmark ViT base model ([111](https://github.com/owkin/GrAIdient/pull/111))\
🐛 **fix:** run on Apple Silicon ([110](https://github.com/owkin/GrAIdient/pull/110))\
Expand Down
100 changes: 94 additions & 6 deletions Sources/GrAIdient/Core/Function/Activation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -767,23 +767,23 @@ public class Sigmoid: ActivationFunction
}
}

/// GELU activation function.
public class GELU: ActivationFunction
/// GELU approximative activation function.
public class GELUApprox: ActivationFunction
{
public static let str = "GELU"
public static let str = "GELUApprox"

/// Forward GPU kernel.
public override var forwardKernel: String
{
get {
return "forwardGELU"
return "forwardGELUApprox"
}
}
/// Backward GPU kernel.
public override var backwardKernel: String
{
get {
return "backwardGELU"
return "backwardGELUApprox"
}
}

Expand Down Expand Up @@ -865,6 +865,83 @@ public class GELU: ActivationFunction
}
}

/// Exact GELU activation function, computed via the error function.
public class GELU: ActivationFunction
{
    public static let str = "GELU"
    
    /// Forward GPU kernel.
    public override var forwardKernel: String { "forwardGELU" }
    /// Backward GPU kernel.
    public override var backwardKernel: String { "backwardGELU" }
    
    ///
    /// Coefficient to apply during the weights initialization.
    ///
    /// - Returns: The coefficient.
    ///
    open override var coeffInitWeights: Float
    {
        Float(sqrt(2.0))
    }
    
    /// Create a GELU activation function.
    init()
    {
        super.init(GELU.str)
    }
    
    ///
    /// Decode from the disk.
    ///
    /// Throw an error if reading from the decoder fails, or
    /// if the data read is corrupted or otherwise invalid.
    ///
    /// - Parameter decoder: The decoder to read data from.
    ///
    required public init(from decoder: Decoder) throws
    {
        try super.init(from: decoder)
    }
    
    ///
    /// Forward CPU.
    ///
    /// Computes x * Φ(x) where Φ is the standard normal CDF.
    ///
    /// - Parameter x: The input.
    /// - Returns: The output.
    ///
    public override func apply(_ x: Double) -> Double
    {
        // Φ(x) expressed with erf; halving by 0.5 is exact in floating point.
        let cdf = 0.5 * (1 + erf(x / sqrt(2.0)))
        return x * cdf
    }
    
    ///
    /// Backward CPU.
    ///
    /// Derivative of x * Φ(x) is Φ(x) + x * φ(x), with φ the normal PDF.
    ///
    /// - Parameter x: The input.
    /// - Returns: The output.
    ///
    public override func derivate(_ x: Double) -> Double
    {
        let cdf = 0.5 * (1.0 + erf(x / sqrt(2.0)))
        let xTimesPdf = x / sqrt(2.0 * Double.pi) * exp(-x * x / 2.0)
        return cdf + xTimesPdf
    }
}

/// Factory API to build an activation function.
public protocol ActivationKernel
{
Expand All @@ -886,6 +963,7 @@ class ActivationKernelImpl: ActivationKernel
LeakyReLU.str: LeakyReLUKernel(),
SoftReLU.str: SoftReLUKernel(),
Sigmoid.str: SigmoidKernel(),
GELUApprox.str: GELUApproxKernel(),
GELU.str: GELUKernel()
]

Expand Down Expand Up @@ -954,7 +1032,17 @@ private class SigmoidKernel: ActivationKernelImpl
}
}

/// Factory to build a Sigmoid function.
/// Factory to build a GELU approximative function.
private class GELUApproxKernel: ActivationKernelImpl
{
    /// Build a GELU approximative activation function.
    override func build() -> ActivationFunction
    {
        return GELUApprox()
    }
}

/// Factory to build a GELU function.
private class GELUKernel: ActivationKernelImpl
{
    /// Build a GELU activation function.
Expand Down
43 changes: 28 additions & 15 deletions Sources/GrAIdient/LayerSeq/ValueSeq.swift
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,16 @@ public class ValueSelfSeq: LayerMergeSeq

if _layersPrev[0].computeDelta
{
if _layersPrev[0].dirty
{
for elem in 0..<batchSize {
for seqK in 0..<sequence {
for depth in 0..<nbNeurons * _nbBlocksPrev
{
value.get(seqK, depth)!.v[elem].delta = 0
}}}
}

for elem in 0..<batchSize {
for head in 0..<_nbHeads {
for seqK in 0..<sequence {
Expand All @@ -1015,18 +1025,9 @@ public class ValueSelfSeq: LayerMergeSeq
sum += deltaCur * scoreTmp
}

if _layersPrev[0].dirty
{
value.get(
seqK, depth + _valueOffset * nbNeurons
)!.v[elem].delta = sum
}
else
{
value.get(
seqK, depth + _valueOffset * nbNeurons
)!.v[elem].delta += sum
}
value.get(
seqK, depth + _valueOffset * nbNeurons
)!.v[elem].delta += sum
}}}}
}
if _layersPrev[1].computeDelta
Expand Down Expand Up @@ -1095,7 +1096,20 @@ public class ValueSelfSeq: LayerMergeSeq
{
try value.checkStateBackwardGPU(batchSize: batchSize)

let pDirty: [UInt32] = value.dirty ? [1] : [0]
if value.dirty
{
let nbElems = value.delta.nbElems
let pNbElems: [UInt32] = [UInt32(nbElems)]

command = MetalKernel.get.createCommand(
"reset", deviceID: deviceID
)
command.setBytes(pNbElems, atIndex: 0)
command.setBuffer(value.delta.metal, atIndex: 1)

command.dispatchThreads(nbElems)
command.enqueue()
}

let kernel = (nbNeurons / _nbHeads) % 4 == 0 ?
"valueSelfValueSeq4Backward" : "valueSelfValueSeqBackward"
Expand All @@ -1112,8 +1126,7 @@ public class ValueSelfSeq: LayerMergeSeq
command.setBytes(pGlobalOffset, atIndex: 6)
command.setBytes(pNbBatch, atIndex: 7)
command.setBytes(pSequence, atIndex: 8)
command.setBytes(pDirty, atIndex: 9)
command.setBuffer(value.delta.metal, atIndex: 10)
command.setBuffer(value.delta.metal, atIndex: 9)

command.dispatchThreads(
width: nbNeurons / coeff,
Expand Down
94 changes: 92 additions & 2 deletions Sources/GrAIdient/Metal/Kernel/Activation.metal
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ kernel void backwardSigmoid(
delta[id] = delta[id] * derivative;
}

kernel void forwardGELU(
kernel void forwardGELUApprox(
constant uint * pNbElems,
device float * tmps,
device float * outs,
Expand Down Expand Up @@ -275,7 +275,7 @@ kernel void forwardGELU(
outs[id] = 0.5 * x * (1 + tmp2);
}

kernel void backwardGELU(
kernel void backwardGELUApprox(
const device float * tmps,
constant uint * pNbElems,
device float * delta,
Expand Down Expand Up @@ -311,3 +311,93 @@ kernel void backwardGELU(
float derivative = 0.5 * (1 + tmp2 + x * tmp3);
delta[id] = delta[id] * derivative;
}

/*
 * Approximation to the error function.
 * Based on code from:
 * https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
 *
 * Two polynomial branches are selected on |a|; the exact fma chains and
 * their ordering determine the rounding behavior, so do not reorder them.
 */
float erf(float a)
{
    float r, s, t, u;
    t = metal::abs(a); // t = |a|
    s = a * a;         // s = a^2
    if (t > 0.927734375f) // large-|a| branch: 1 - exp(poly(t))
    {
        // maximum error 0.99527 ulp
        r = metal::fma(-1.72853470e-5f, t, 3.83197126e-4f); // -0x1.220000p-16,0x1.91cfb2p-12
        u = metal::fma(-3.88396438e-3f, t, 2.42546219e-2f); // -0x1.fd1438p-9, 0x1.8d6342p-6
        r = metal::fma(r, s, u);
        r = metal::fma(r, t, -1.06777877e-1f); // -0x1.b55cb8p-4
        r = metal::fma(r, t, -6.34846687e-1f); // -0x1.450aa0p-1
        r = metal::fma(r, t, -1.28717512e-1f); // -0x1.079d0cp-3
        r = metal::fma(r, t, -t);
        // TODO, replace with expm1 when implemented
        r = 1.0f - metal::exp(r);
        // restore the sign of the input (erf is odd)
        r = metal::copysign(r, a);
    }
    else // small-|a| branch: odd polynomial in a, evaluated in a^2
    {
        // maximum error 0.98929 ulp
        r = -5.96761703e-4f; // -0x1.38e000p-11
        r = metal::fma(r, s, 4.99119423e-3f); // 0x1.471a58p-8
        r = metal::fma(r, s, -2.67681349e-2f); // -0x1.b691b2p-6
        r = metal::fma(r, s, 1.12819925e-1f); // 0x1.ce1c44p-4
        r = metal::fma(r, s, -3.76125336e-1f); // -0x1.812700p-2
        r = metal::fma(r, s, 1.28379166e-1f); // 0x1.06eba8p-3
        r = metal::fma(r, a, a); // final step carries the sign of a
    }
    return r;
}

// Forward pass of the exact GELU: outs[id] = x * Phi(x), with the
// pre-activation value saved into tmps for the backward pass.
kernel void forwardGELU(
    constant uint * pNbElems,
    device float * tmps,
    device float * outs,
    uint id [[ thread_position_in_grid ]])
{
    // Guard against a missing parameter buffer.
    if (!pNbElems)
    {
        return ;
    }
    
    const uint nbElems = pNbElems[0];
    // Guard against threads dispatched past the buffer end.
    if (id >= nbElems)
    {
        return ;
    }
    
    const float x = outs[id];
    tmps[id] = x; // stash the input for backwardGELU
    outs[id] = 0.5 * x * (1 + erf(x / sqrt(2.0)));
}

// Backward pass of the exact GELU: scales the incoming gradient by
// d/dx [x * Phi(x)] = Phi(x) + x * phi(x), using the input cached in tmps.
kernel void backwardGELU(
    const device float * tmps,
    constant uint * pNbElems,
    device float * delta,
    uint id [[ thread_position_in_grid ]])
{
    // Guard against a missing parameter buffer.
    if (!pNbElems)
    {
        return ;
    }
    
    const uint nbElems = pNbElems[0];
    // Guard against threads dispatched past the buffer end.
    if (id >= nbElems)
    {
        return ;
    }
    
    const float x = tmps[id];
    const float cdf = 0.5 * (1.0 + erf(x / sqrt(2.0)));            // Phi(x)
    const float xTimesPdf = x / sqrt(2.0 * M_PI_F) * exp(-x * x / 2.0); // x * phi(x)
    delta[id] = delta[id] * (cdf + xTimesPdf);
}
Loading

0 comments on commit 096b95d

Please sign in to comment.