Skip to content

Commit

Permalink
sync pos. (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
b4rtaz authored May 5, 2024
1 parent b2f3450 commit d5e8f89
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 46 deletions.
4 changes: 2 additions & 2 deletions src/grok1-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ void grokMulInput(TASK_ARGS) {
}

// source: https://github.com/karpathy/llama2.c/pull/408
void ropeFalcon(float* q, float* k, TransformerSpec* spec, int pos, float theta) {
void ropeFalcon(float* q, float* k, TransformerSpec* spec, pos_t pos, float theta) {
for (int i = 0; i < spec->nHeads; i++) {
for (int j = 0; j < spec->headSize / 2; j++) {
float freq = 1.0f / powf(theta, 2.0f * (float)j / (float)spec->headSize);
Expand Down Expand Up @@ -301,7 +301,7 @@ TransformerArch buildGrok1Arch(TransformerSpec* spec) {

// inference

a.I(sendPoke, TASK_TYPE_TRANSFER);
a.I(sendPos, TASK_TYPE_TRANSFER);
a.I(grokMulInput, TASK_TYPE_INFERENCE);
for (int i = 0; i < spec->nLayers; i++) {
a.I(llamaRmsAtt, TASK_TYPE_INFERENCE);
Expand Down
2 changes: 1 addition & 1 deletion src/llama2-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ TransformerArch buildLlama2Arch(TransformerSpec* spec) {

// inference

a.I(sendPoke, TASK_TYPE_TRANSFER);
a.I(sendPos, TASK_TYPE_TRANSFER);
for (int i = 0; i < spec->nLayers; i++) {
a.I(llamaRmsAtt, TASK_TYPE_INFERENCE);
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
Expand Down
8 changes: 4 additions & 4 deletions src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void generate(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer
long start = 0; // used to time our code, only initialized after first iteration
int next; // will store the next token in the sequence
int token = promptTokens[0]; // kick off with the first token in the prompt
int pos = 0; // position in the sequence
pos_t pos = 0; // position in the sequence

unsigned long inferenceTime;
unsigned long transferTime;
Expand Down Expand Up @@ -139,7 +139,7 @@ void chat(Inference* inference, SocketPool* socketPool, Tokenizer* tokenizer, Sa
int next; // will store the next token in the sequence
int token; // stores the current token to feed into the transformer
int prev_token;
int pos = 0; // position in the sequence
pos_t pos = 0; // position in the sequence
while (pos < args->steps) {
// when it is the user's turn to contribute tokens to the dialog...
if (userTurn) {
Expand Down Expand Up @@ -236,9 +236,9 @@ void simpleServer(Inference* inference, SocketPool* socketPool, Tokenizer *token
tokenizer->encode(prompt, promptTokens, &nPromptTokens, true, false);

int token = promptTokens[0];
int maxPos = nPromptTokens + maxTokens;
pos_t maxPos = nPromptTokens + maxTokens;
if (maxPos > spec->seqLen) maxPos = spec->seqLen;
for (int pos = 0; pos < maxPos; pos++) {
for (pos_t pos = 0; pos < maxPos; pos++) {
float* logits = inference->infer(token, pos);

if (pos < nPromptTokens - 1) {
Expand Down
2 changes: 1 addition & 1 deletion src/mixtral-tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ TransformerArch buildMixtralArch(TransformerSpec* spec) {

// inference

a.I(sendPoke, TASK_TYPE_TRANSFER);
a.I(sendPos, TASK_TYPE_TRANSFER);
for (int i = 0; i < spec->nLayers; i++) {
a.I(llamaRmsAtt, TASK_TYPE_INFERENCE);
a.I(llamaRmsAttNorm, TASK_TYPE_INFERENCE);
Expand Down
12 changes: 6 additions & 6 deletions src/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ static inline void setNoDelay(int socket) {
throw std::runtime_error("Error setting socket to no-delay");
}

static inline void writeSocket(int socket, const char* data, size_t size) {
static inline void writeSocket(int socket, const void* data, size_t size) {
while (size > 0) {
int s = send(socket, (char*)data, size, 0);
if (s < 0) {
Expand All @@ -50,7 +50,7 @@ static inline void writeSocket(int socket, const char* data, size_t size) {
}
}

static inline void readSocket(bool* isNonBlocking, int socket, char* data, size_t size) {
static inline void readSocket(bool* isNonBlocking, int socket, void* data, size_t size) {
unsigned int attempt = 0;
time_t startTime;
while (size > 0) {
Expand Down Expand Up @@ -136,13 +136,13 @@ SocketPool::~SocketPool() {
delete[] isNonBlocking;
}

void SocketPool::write(unsigned int socketIndex, const char* data, size_t size) {
void SocketPool::write(unsigned int socketIndex, const void* data, size_t size) {
assert(socketIndex >= 0 && socketIndex < nSockets);
sentBytes += size;
writeSocket(sockets[socketIndex], data, size);
}

void SocketPool::read(unsigned int socketIndex, char* data, size_t size) {
void SocketPool::read(unsigned int socketIndex, void* data, size_t size) {
assert(socketIndex >= 0 && socketIndex < nSockets);
recvBytes += size;
readSocket(&isNonBlocking[socketIndex], sockets[socketIndex], data, size);
Expand Down Expand Up @@ -236,11 +236,11 @@ Socket::~Socket() {
close(socket);
}

void Socket::write(const char* data, size_t size) {
void Socket::write(const void* data, size_t size) {
writeSocket(socket, data, size);
}

void Socket::read(char* data, size_t size) {
void Socket::read(void* data, size_t size) {
readSocket(&isNonBlocking, socket, data, size);
}

Expand Down
10 changes: 5 additions & 5 deletions src/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class WriteSocketException : public std::exception {

struct SocketIo {
unsigned int socketIndex;
const char* data;
const void* data;
size_t size;
};

Expand All @@ -39,8 +39,8 @@ class SocketPool {
SocketPool(unsigned int nSockets, int* sockets);
~SocketPool();

void write(unsigned int socketIndex, const char* data, size_t size);
void read(unsigned int socketIndex, char* data, size_t size);
void write(unsigned int socketIndex, const void* data, size_t size);
void read(unsigned int socketIndex, void* data, size_t size);
void writeMany(unsigned int n, SocketIo* ios);
void readMany(unsigned int n, SocketIo* ios);
void getStats(size_t* sentBytes, size_t* recvBytes);
Expand All @@ -55,8 +55,8 @@ class Socket {
Socket(int socket);
~Socket();

void write(const char* data, size_t size);
void read(char* data, size_t size);
void write(const void* data, size_t size);
void read(void* data, size_t size);
};

class SocketServer {
Expand Down
20 changes: 8 additions & 12 deletions src/tasks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,27 +160,23 @@ void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, Tra
}
}

void sendPoke(TASK_ARGS) {
void sendPos(TASK_ARGS) {
TASK_VARIABLES;

if (ctx->socketPool != NULL) {
const char poke = 0x25;

unsigned int nSockets = ctx->socketPool->nSockets / nThreads + (ctx->socketPool->nSockets % nThreads > threadIndex ? 1 : 0);
SocketIo ios[nSockets];
for (int i = 0; i < nSockets; i++) {
ios[i].socketIndex = threadIndex + i * nThreads;
ios[i].data = &poke;
ios[i].size = sizeof(char);
ios[i].data = &transformer->pos;
ios[i].size = sizeof(pos_t);
}
ctx->socketPool->writeMany(nSockets, ios);
}
}

void waitForPoke(Socket* socket) {
char poke;
socket->read(&poke, sizeof(char));
assert(poke == 0x25);
void waitForPos(Transformer* transformer, Socket* socket) {
socket->read(&transformer->pos, sizeof(pos_t));
}

Inference::Inference(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, SocketPool* socketPool) {
Expand All @@ -190,15 +186,15 @@ Inference::Inference(TransformerArch* arch, unsigned int nThreads, Transformer*
context.transformer = transformer;
context.socket = NULL;
context.socketPool = socketPool;
assert(arch->inference.tasks[0].handler == sendPoke);
assert(arch->inference.tasks[0].handler == sendPos);
taskLoop = new TaskLoop(nThreads, arch->inference.nTasks, TASK_N_TYPES, arch->inference.tasks, (void*)&context);
}

// Destroys the TaskLoop that the Inference constructor allocated with `new`.
Inference::~Inference() {
delete taskLoop;
}

float* Inference::infer(int token, int pos) {
float* Inference::infer(int token, pos_t pos) {
transformer->pos = pos;

float* contentRow = ((float*)transformer->tokenEmbeddingTable) + token * transformer->spec->dim;
Expand Down Expand Up @@ -231,7 +227,7 @@ Worker::~Worker() {

void Worker::work() {
while (true) {
waitForPoke(socket);
waitForPos(transformer, socket);

context.currentBlockIndex = 0;
taskLoop->run();
Expand Down
4 changes: 2 additions & 2 deletions src/tasks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void syncMissingSlicesOfSlicedBuffer(unsigned int nThreads, unsigned int threadI
void quantizeUnitBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
void quantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool quantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
void dequantizeSlicedBuffer(unsigned int nThreads, unsigned int threadIndex, TransformerContext* ctx, bool dequantizeRootSlice, uint8_t sourceBufferIndex, uint8_t targetBufferIndex);
void sendPoke(TASK_ARGS);
void sendPos(TASK_ARGS);

class Inference {
private:
Expand All @@ -60,7 +60,7 @@ class Inference {
public:
Inference(TransformerArch* arch, unsigned int nThreads, Transformer* transformer, SocketPool* socketPool);
~Inference();
float* infer(int token, int pos);
float* infer(int token, pos_t pos);
void getStats(unsigned long* inferenceTime, unsigned long* transferTime);
};

Expand Down
12 changes: 2 additions & 10 deletions src/transformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,8 @@ size_t MatmulSlice::splitWeights(uint8_t sliceIndex, char* weights, char* weight
return copiedBytes;
}

// Copies the d0 floats of one slice's result (output0) into its place inside
// the merged output buffer.
//
// sliceIndex - index of the slice whose result is being merged
// output     - destination buffer holding the results of all slices
// output0    - source buffer with this slice's d0 values
//
// Returns the position (counted in floats, not bytes) in `output` at which the
// slice was written.
long MatmulSlice::mergeOutputs(uint8_t sliceIndex, float* output, float* output0) {
    const long base = this->d0 * sliceIndex;
    float* destination = output + base;
    for (int k = 0; k < this->d0; k++)
        destination[k] = output0[k];
    return base;
}

void initRope(float* cache, TransformerSpec* spec) {
for (int pos = 0; pos < spec->seqLen; pos++) {
for (pos_t pos = 0; pos < spec->seqLen; pos++) {
for (int i = 0; i < spec->dim; i += 2) {
int head_dim = i % spec->headSize;
float freq = 1.0f / powf(spec->ropeTheta, head_dim / (float)spec->headSize);
Expand All @@ -65,7 +57,7 @@ void initRope(float* cache, TransformerSpec* spec) {
}
}

void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex) {
void rope(float* cache, float* q, float* k, TransformerSpec* spec, pos_t pos, unsigned int nThreads, unsigned int threadIndex) {
int halfDim = spec->dim / 2;
int slice = halfDim / nThreads;
int iStart = threadIndex * slice;
Expand Down
7 changes: 4 additions & 3 deletions src/transformer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include "quants.hpp"
#include "socket.hpp"

// Type of a token position within the sequence. It is stored in
// Transformer::pos and sent as raw bytes (sizeof(pos_t)) from the root node
// to the workers, so every node must be built with the same definition.
//
// A 16-bit type wraps at 65535: loops of the form
// `for (pos_t pos = 0; pos < spec->seqLen; pos++)` would never terminate for a
// larger bound, and `pos_t maxPos = nPromptTokens + maxTokens` would truncate
// silently. A 32-bit unsigned type removes that limit at the cost of two extra
// bytes per position sync.
typedef unsigned int pos_t;

class MatmulSlice {
public:
FloatType type;
Expand All @@ -17,7 +19,6 @@ class MatmulSlice {

MatmulSlice(FloatType type, int nSlices, int n, int d);
size_t splitWeights(uint8_t sliceIndex, char* weights, char* weights0);
long mergeOutputs(uint8_t sliceIndex, float* output, float* output0);
};

enum TransformerHeaderKey {
Expand Down Expand Up @@ -84,7 +85,7 @@ struct TransformerSpec {
};

void initRope(float* cache, TransformerSpec* spec);
void rope(float* cache, float* q, float* k, TransformerSpec* spec, int pos, unsigned int nThreads, unsigned int threadIndex);
void rope(float* cache, float* q, float* k, TransformerSpec* spec, pos_t pos, unsigned int nThreads, unsigned int threadIndex);

class TransformerBlock {
public:
Expand Down Expand Up @@ -185,8 +186,8 @@ class Transformer {
size_t wclsBytes;
char* wcls;

pos_t pos;
float rms;
int pos;
float* x;
float* logits;
float* ropeCache;
Expand Down

0 comments on commit d5e8f89

Please sign in to comment.