fix: chunked stream, close stream without econnreset. (#65)
b4rtaz authored May 26, 2024
1 parent 4575818 commit ef1e312
Showing 2 changed files with 96 additions and 83 deletions.
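For context on the fix: with Transfer-Encoding: chunked, every write is framed as a hex length, CRLF, the payload, and another CRLF, and the response only ends cleanly once a zero-size chunk plus a final CRLF has been sent. If the server closes the socket before emitting that terminator, clients tend to report the abrupt close as ECONNRESET. Below is a minimal sketch of the framing; frameChunk is an illustrative standalone helper, not the commit's code.

    #include <sstream>
    #include <string>

    // Frame one chunk of a chunked HTTP/1.1 body; when `last` is true, append the
    // zero-size terminating chunk that signals a clean end of the stream.
    std::string frameChunk(const std::string& data, bool last) {
        std::ostringstream out;
        out << std::hex << data.size() << "\r\n" << data << "\r\n";
        if (last)
            out << "0\r\n\r\n"; // terminator; "0000" as used in the diff parses the same way
        return out.str();
    }
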
176 changes: 95 additions & 81 deletions src/apps/dllama-api/dllama-api.cpp
@@ -85,14 +85,10 @@ class HttpParser {
// Parse body
std::getline(iss, httpRequest.body, '\0');

// Parse JSON if Content-Type is application/json
httpRequest.parsedJson = json::object();
if(httpRequest.headers.find("Content-Type") != httpRequest.headers.end()){
if(httpRequest.headers["Content-Type"] == "application/json"){
httpRequest.parsedJson = json::parse(httpRequest.body);
}
if (httpRequest.body.size() > 0) {
// printf("body: %s\n", httpRequest.body.c_str());
httpRequest.parsedJson = json::parse(httpRequest.body);
}

return httpRequest;
}
private:
@@ -189,6 +185,14 @@ void to_json(json& j, const Choice& choice) {
j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}};
}

std::string createJsonResponse(std::string json) {
std::ostringstream oss;
oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << json.length() << "\r\n\r\n" << json;
return oss.str();
}

struct ChatCompletionChunk {
std::string id;
std::string object;
@@ -197,7 +201,7 @@ struct ChatCompletionChunk {
std::vector<ChunkChoice> choices;

ChatCompletionChunk(ChunkChoice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("Distributed Model") {
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
@@ -231,7 +235,7 @@ struct ChatCompletion {
ChatUsage usage;

ChatCompletion(Choice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("Distributed Model") {
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
@@ -288,90 +292,93 @@ std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage>
return oss.str();
}

void outputChatCompletionChunk(Socket &client_socket, const std::string &delta, const std::string &finish_reason){
ChunkChoice choice;

if(finish_reason.size() > 0){
choice.finish_reason = finish_reason;
void writeChunk(Socket& socket, const std::string data, const bool stop) {
std::ostringstream formattedChunk;
formattedChunk << std::hex << data.size() << "\r\n" << data << "\r\n";
if (stop) {
formattedChunk << "0000\r\n\r\n";
}
else{
socket.write(formattedChunk.str().c_str(), formattedChunk.str().size());
}

void writeChatCompletionChunk(Socket &client_socket, const std::string &delta, const bool stop){
ChunkChoice choice;
if (stop) {
choice.finish_reason = "stop";
} else {
choice.delta = ChatMessageDelta("assistant", delta);
}

ChatCompletionChunk chunk = ChatCompletionChunk(choice);

std::ostringstream oss;

oss << "data: " << ((json)chunk).dump() << "\n\n";

if(finish_reason.size() > 0){
oss << "data: [DONE]\n\n";
}

std::string chunkResponse = oss.str();

// Format the chunked response
std::ostringstream formattedChunk;
formattedChunk << std::hex << chunkResponse.length() << "\r\n" << chunkResponse << "\r\n";
std::ostringstream buffer;
buffer << "data: " << ((json)chunk).dump() << "\r\n\r\n";
writeChunk(client_socket, buffer.str(), false);

client_socket.write(formattedChunk.str().c_str(), formattedChunk.str().length());
if (stop) {
writeChunk(client_socket, "data: [DONE]", true);
}
}

void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
printf("Handling Completion Request\n");
void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
// Set inference arguments
InferenceParams inferParams;
inferParams.temperature = args->temperature;
inferParams.top_p = args->topp;
inferParams.seed = args->seed;
inferParams.stream = false;
inferParams.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"]));
inferParams.max_tokens = spec->seqLen - inferParams.prompt.size();

if(request.parsedJson.contains("stream")){
inferParams.stream = request.parsedJson["stream"].template get<bool>();
InferenceParams params;
params.temperature = args->temperature;
params.top_p = args->topp;
params.seed = args->seed;
params.stream = false;
params.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"]));
params.max_tokens = spec->seqLen - params.prompt.size();

if (request.parsedJson.contains("stream")) {
params.stream = request.parsedJson["stream"].get<bool>();
}
if(request.parsedJson.contains("temperature")){
inferParams.temperature = request.parsedJson["temperature"].template get<float>();
assert(inferParams.temperature >= 0.0f);
sampler->setTemp(inferParams.temperature);
if (request.parsedJson.contains("temperature")) {
params.temperature = request.parsedJson["temperature"].template get<float>();
assert(params.temperature >= 0.0f);
sampler->setTemp(params.temperature);
}
if(request.parsedJson.contains("seed")){
inferParams.seed = request.parsedJson["seed"].template get<unsigned long long>();
sampler->setSeed(inferParams.seed);
if (request.parsedJson.contains("seed")) {
params.seed = request.parsedJson["seed"].template get<unsigned long long>();
sampler->setSeed(params.seed);
}
if(request.parsedJson.contains("max_tokens")){
inferParams.max_tokens = request.parsedJson["max_tokens"].template get<int>();
assert(inferParams.max_tokens <= spec->seqLen); //until rope scaling or similiar is implemented
if (request.parsedJson.contains("max_tokens")) {
params.max_tokens = request.parsedJson["max_tokens"].template get<int>();
assert(params.max_tokens <= spec->seqLen); //until rope scaling or similiar is implemented
}
if(request.parsedJson.contains("stop")){
inferParams.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
if (request.parsedJson.contains("stop")) {
params.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
} else {
const std::string defaultStop = "<|eot_id|>";
params.stop = std::vector<std::string>{defaultStop};
}

printf("🔸");
fflush(stdout);

//Process the chat completion request
std::vector<std::string> generated;
generated.get_allocator().allocate(inferParams.max_tokens);
generated.get_allocator().allocate(params.max_tokens);

if (inferParams.stream) {
if (params.stream) {
std::ostringstream oss;
oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
<< "Connection: keep-alive\r\n"
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";

client_socket.write(oss.str().c_str(), oss.str().length());
socket.write(oss.str().c_str(), oss.str().length());
}

int promptLength = inferParams.prompt.length();
int promptLength = params.prompt.length();
int nPromptTokens;
int promptTokens[promptLength + 3];
char prompt[promptLength + 1];
prompt[promptLength] = 0;
strcpy(prompt, inferParams.prompt.c_str());
strcpy(prompt, params.prompt.c_str());
tokenizer->encode(prompt, promptTokens, &nPromptTokens, true, false);

int token = promptTokens[0];
pos_t maxPos = nPromptTokens + inferParams.max_tokens;
pos_t maxPos = nPromptTokens + params.max_tokens;
if (maxPos > spec->seqLen) maxPos = spec->seqLen;
bool eosEncountered = false;
for (pos_t pos = 0; pos < maxPos; pos++) {
@@ -390,15 +397,15 @@ void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Infer

bool safePiece = isSafePiece(piece);

if (!inferParams.stop.empty() && safePiece) {
if (!params.stop.empty() && safePiece) {
std::string concatenatedTokens;
int startIndex = std::max(0, static_cast<int>(generated.size()) - 7);
for (int i = startIndex; i < generated.size(); ++i) {
concatenatedTokens += generated[i];
}
concatenatedTokens += std::string(piece);

for (const auto& word : inferParams.stop) {
for (const auto& word : params.stop) {
if (concatenatedTokens.find(word) != std::string::npos) {
eosEncountered = true;
break;
@@ -409,49 +416,56 @@ void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Infer
if (eosEncountered) break;

std::string string = std::string(piece);

//char string[100];
//strcpy(string, piece);
safePrintf(piece);
fflush(stdout);

generated.push_back(string);

if (inferParams.stream) {
outputChatCompletionChunk(client_socket, string, "");
if (params.stream) {
writeChatCompletionChunk(socket, string, false);
}
}
}

if (!inferParams.stream) {
if (!params.stream) {
ChatMessage chatMessage = ChatMessage("assistant", std::accumulate(generated.begin(), generated.end(), std::string("")));
Choice responseChoice = Choice(chatMessage);
ChatCompletion completion = ChatCompletion(responseChoice);
completion.usage = ChatUsage(nPromptTokens, generated.size(), nPromptTokens + generated.size());

std::string chatJson = ((json)completion).dump();

std::ostringstream oss;

oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << chatJson.length() << "\r\n\r\n" << chatJson;

std::string response = oss.str();

client_socket.write(response.c_str(), response.length());
std::string response = createJsonResponse(chatJson);
socket.write(response.c_str(), response.length());
} else {
outputChatCompletionChunk(client_socket, "", "stop");
writeChatCompletionChunk(socket, "", true);
}
printf("🔶\n");
fflush(stdout);
}

void handleModelsRequest(Socket& client_socket, HttpRequest& request) {
std::string response = createJsonResponse(
"{ \"object\": \"list\","
"\"data\": ["
"{ \"id\": \"dl\", \"object\": \"model\", \"created\": 0, \"owned_by\": \"user\" }"
"] }");
client_socket.write(response.c_str(), response.length());
}

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
SocketServer* server = new SocketServer(args->port);
printf("Server URL: http://127.0.0.1:%d/v1/\n", args->port);

std::vector<Route> routes = {
{
"/v1/chat/completions",
HttpMethod::METHOD_POST,
std::bind(&handleCompletionsRequest, std::placeholders::_1, std::placeholders::_2, inference, tokenizer, sampler, args, spec)
},
{
"/v1/models",
HttpMethod::METHOD_GET,
std::bind(&handleModelsRequest, std::placeholders::_1, std::placeholders::_2)
}
};

@@ -464,7 +478,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
// Parse the HTTP request
HttpRequest request = HttpParser::parseRequest(std::string(httpRequest.begin(), httpRequest.end()));
// Handle the HTTP request
printf("New Request: %s %s\n", request.getMethod().c_str(), request.path.c_str());
printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
Router::routeRequest(client, request, routes);
} catch (ReadSocketException& ex) {
printf("Read socket error: %d %s\n", ex.code, ex.message);
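To see the client side of the same contract, here is a minimal sketch of decoding a chunked body that has already been buffered in full; decodeChunked is illustrative and assumes a complete, well-formed body, while real clients read incrementally. The point is that the zero-size chunk is the only clean end-of-stream signal, so a socket closed before it arrives surfaces as a connection reset.

    #include <cstddef>
    #include <string>

    // Decode a fully buffered chunked body; stops at the zero-size terminating chunk.
    std::string decodeChunked(const std::string& body) {
        std::string out;
        std::size_t pos = 0;
        while (true) {
            std::size_t lineEnd = body.find("\r\n", pos);
            if (lineEnd == std::string::npos) break;                 // truncated or malformed
            std::size_t size = std::stoul(body.substr(pos, lineEnd - pos), nullptr, 16);
            if (size == 0) break;                                     // terminator: clean end of stream
            out.append(body, lineEnd + 2, size);                      // copy the chunk payload
            pos = lineEnd + 2 + size + 2;                             // skip payload and its trailing CRLF
        }
        return out;
    }
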
3 changes: 1 addition & 2 deletions src/socket.cpp
@@ -230,7 +230,6 @@ Socket SocketServer::accept() {
throw std::runtime_error("Error accepting connection");
setNoDelay(clientSocket);
setQuickAck(clientSocket);
printf("Client connected\n");
return Socket(clientSocket);
}

@@ -261,7 +260,7 @@ bool Socket::tryRead(void* data, size_t size, unsigned long maxAttempts) {

std::vector<char> Socket::readHttpRequest() {
std::vector<char> httpRequest;
char buffer[1024]; // Initial buffer size
char buffer[1024 * 1024]; // TODO: this should be refactored asap
ssize_t bytesRead;

// Peek into the socket buffer to check available data
