Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: chunked stream, close stream without econnreset. #65

Merged
merged 1 commit into from
May 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 95 additions & 81 deletions src/apps/dllama-api/dllama-api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,10 @@ class HttpParser {
// Parse body
std::getline(iss, httpRequest.body, '\0');

// Parse JSON if Content-Type is application/json
httpRequest.parsedJson = json::object();
if(httpRequest.headers.find("Content-Type") != httpRequest.headers.end()){
if(httpRequest.headers["Content-Type"] == "application/json"){
httpRequest.parsedJson = json::parse(httpRequest.body);
}
if (httpRequest.body.size() > 0) {
// printf("body: %s\n", httpRequest.body.c_str());
httpRequest.parsedJson = json::parse(httpRequest.body);
}

return httpRequest;
}
private:
Expand Down Expand Up @@ -189,6 +185,14 @@ void to_json(json& j, const Choice& choice) {
j = json{{"index", choice.index}, {"message", choice.message}, {"finish_reason", choice.finish_reason}};
}

// Builds a complete HTTP/1.1 200 response carrying a JSON payload:
// status line, Content-Type and Content-Length headers, then the body.
// `payload` is taken by const reference (the original copied it by value),
// and is renamed from `json` to avoid shadowing the file's `json` type alias.
std::string createJsonResponse(const std::string& payload) {
    std::ostringstream oss;
    oss << "HTTP/1.1 200 OK\r\n"
        << "Content-Type: application/json; charset=utf-8\r\n"
        << "Content-Length: " << payload.length() << "\r\n\r\n" << payload;
    return oss.str();
}

struct ChatCompletionChunk {
std::string id;
std::string object;
Expand All @@ -197,7 +201,7 @@ struct ChatCompletionChunk {
std::vector<ChunkChoice> choices;

ChatCompletionChunk(ChunkChoice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("Distributed Model") {
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
Expand Down Expand Up @@ -231,7 +235,7 @@ struct ChatCompletion {
ChatUsage usage;

ChatCompletion(Choice &choice_)
: id("chatcmpl-test"), object("chat.completion"), model("Distributed Model") {
: id("chatcmpl-test"), object("chat.completion"), model("dl") {
created = std::time(nullptr); // Set created to current Unix timestamp
choices.push_back(choice_);
}
Expand Down Expand Up @@ -288,90 +292,93 @@ std::string buildChatPrompt(Tokenizer *tokenizer, const std::vector<ChatMessage>
return oss.str();
}

void outputChatCompletionChunk(Socket &client_socket, const std::string &delta, const std::string &finish_reason){
ChunkChoice choice;

if(finish_reason.size() > 0){
choice.finish_reason = finish_reason;
void writeChunk(Socket& socket, const std::string data, const bool stop) {
std::ostringstream formattedChunk;
formattedChunk << std::hex << data.size() << "\r\n" << data << "\r\n";
if (stop) {
formattedChunk << "0000\r\n\r\n";
}
else{
socket.write(formattedChunk.str().c_str(), formattedChunk.str().size());
}

// Emits one OpenAI-style SSE event ("data: {chunk json}\r\n\r\n") over the
// chunked HTTP stream. When `stop` is true the chunk carries
// finish_reason="stop" (no delta), and the "data: [DONE]" sentinel plus the
// chunked-stream terminator are written afterwards; otherwise the chunk
// carries an assistant delta with the given text.
void writeChatCompletionChunk(Socket &client_socket, const std::string &delta, const bool stop){
    ChunkChoice choice;
    if (stop)
        choice.finish_reason = "stop";
    else
        choice.delta = ChatMessageDelta("assistant", delta);

    ChatCompletionChunk chunk(choice);
    std::ostringstream event;
    event << "data: " << ((json)chunk).dump() << "\r\n\r\n";
    writeChunk(client_socket, event.str(), false);

    if (stop)
        writeChunk(client_socket, "data: [DONE]", true);
}

void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
printf("Handling Completion Request\n");
void handleCompletionsRequest(Socket& socket, HttpRequest& request, Inference* inference, Tokenizer* tokenizer, Sampler* sampler, AppArgs* args, TransformerSpec* spec) {
// Set inference arguments
InferenceParams inferParams;
inferParams.temperature = args->temperature;
inferParams.top_p = args->topp;
inferParams.seed = args->seed;
inferParams.stream = false;
inferParams.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"]));
inferParams.max_tokens = spec->seqLen - inferParams.prompt.size();

if(request.parsedJson.contains("stream")){
inferParams.stream = request.parsedJson["stream"].template get<bool>();
InferenceParams params;
params.temperature = args->temperature;
params.top_p = args->topp;
params.seed = args->seed;
params.stream = false;
params.prompt = buildChatPrompt(tokenizer, parseChatMessages(request.parsedJson["messages"]));
params.max_tokens = spec->seqLen - params.prompt.size();

if (request.parsedJson.contains("stream")) {
params.stream = request.parsedJson["stream"].get<bool>();
}
if(request.parsedJson.contains("temperature")){
inferParams.temperature = request.parsedJson["temperature"].template get<float>();
assert(inferParams.temperature >= 0.0f);
sampler->setTemp(inferParams.temperature);
if (request.parsedJson.contains("temperature")) {
params.temperature = request.parsedJson["temperature"].template get<float>();
assert(params.temperature >= 0.0f);
sampler->setTemp(params.temperature);
}
if(request.parsedJson.contains("seed")){
inferParams.seed = request.parsedJson["seed"].template get<unsigned long long>();
sampler->setSeed(inferParams.seed);
if (request.parsedJson.contains("seed")) {
params.seed = request.parsedJson["seed"].template get<unsigned long long>();
sampler->setSeed(params.seed);
}
if(request.parsedJson.contains("max_tokens")){
inferParams.max_tokens = request.parsedJson["max_tokens"].template get<int>();
assert(inferParams.max_tokens <= spec->seqLen); //until rope scaling or similiar is implemented
if (request.parsedJson.contains("max_tokens")) {
params.max_tokens = request.parsedJson["max_tokens"].template get<int>();
assert(params.max_tokens <= spec->seqLen); //until rope scaling or similiar is implemented
}
if(request.parsedJson.contains("stop")){
inferParams.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
if (request.parsedJson.contains("stop")) {
params.stop = request.parsedJson["stop"].template get<std::vector<std::string>>();
} else {
const std::string defaultStop = "<|eot_id|>";
params.stop = std::vector<std::string>{defaultStop};
}

printf("🔸");
fflush(stdout);

//Process the chat completion request
std::vector<std::string> generated;
generated.get_allocator().allocate(inferParams.max_tokens);
generated.get_allocator().allocate(params.max_tokens);

if (inferParams.stream) {
if (params.stream) {
std::ostringstream oss;
oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: text/event-stream; charset=utf-8\r\n"
<< "Connection: keep-alive\r\n"
<< "Connection: close\r\n"
<< "Transfer-Encoding: chunked\r\n\r\n";

client_socket.write(oss.str().c_str(), oss.str().length());
socket.write(oss.str().c_str(), oss.str().length());
}

int promptLength = inferParams.prompt.length();
int promptLength = params.prompt.length();
int nPromptTokens;
int promptTokens[promptLength + 3];
char prompt[promptLength + 1];
prompt[promptLength] = 0;
strcpy(prompt, inferParams.prompt.c_str());
strcpy(prompt, params.prompt.c_str());
tokenizer->encode(prompt, promptTokens, &nPromptTokens, true, false);

int token = promptTokens[0];
pos_t maxPos = nPromptTokens + inferParams.max_tokens;
pos_t maxPos = nPromptTokens + params.max_tokens;
if (maxPos > spec->seqLen) maxPos = spec->seqLen;
bool eosEncountered = false;
for (pos_t pos = 0; pos < maxPos; pos++) {
Expand All @@ -390,15 +397,15 @@ void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Infer

bool safePiece = isSafePiece(piece);

if (!inferParams.stop.empty() && safePiece) {
if (!params.stop.empty() && safePiece) {
std::string concatenatedTokens;
int startIndex = std::max(0, static_cast<int>(generated.size()) - 7);
for (int i = startIndex; i < generated.size(); ++i) {
concatenatedTokens += generated[i];
}
concatenatedTokens += std::string(piece);

for (const auto& word : inferParams.stop) {
for (const auto& word : params.stop) {
if (concatenatedTokens.find(word) != std::string::npos) {
eosEncountered = true;
break;
Expand All @@ -409,49 +416,56 @@ void handleCompletionsRequest(Socket& client_socket, HttpRequest& request, Infer
if (eosEncountered) break;

std::string string = std::string(piece);

//char string[100];
//strcpy(string, piece);
safePrintf(piece);
fflush(stdout);

generated.push_back(string);

if (inferParams.stream) {
outputChatCompletionChunk(client_socket, string, "");
if (params.stream) {
writeChatCompletionChunk(socket, string, false);
}
}
}

if (!inferParams.stream) {
if (!params.stream) {
ChatMessage chatMessage = ChatMessage("assistant", std::accumulate(generated.begin(), generated.end(), std::string("")));
Choice responseChoice = Choice(chatMessage);
ChatCompletion completion = ChatCompletion(responseChoice);
completion.usage = ChatUsage(nPromptTokens, generated.size(), nPromptTokens + generated.size());

std::string chatJson = ((json)completion).dump();

std::ostringstream oss;

oss << "HTTP/1.1 200 OK\r\n"
<< "Content-Type: application/json; charset=utf-8\r\n"
<< "Content-Length: " << chatJson.length() << "\r\n\r\n" << chatJson;

std::string response = oss.str();

client_socket.write(response.c_str(), response.length());
std::string response = createJsonResponse(chatJson);
socket.write(response.c_str(), response.length());
} else {
outputChatCompletionChunk(client_socket, "", "stop");
writeChatCompletionChunk(socket, "", true);
}
printf("🔶\n");
fflush(stdout);
}

// GET /v1/models — reports the single model served by this process in the
// OpenAI "list" format so standard OpenAI-compatible clients can discover it.
// `request` is accepted to satisfy the route-handler signature but unused.
void handleModelsRequest(Socket& client_socket, HttpRequest& request) {
    const std::string modelList =
        "{ \"object\": \"list\","
        "\"data\": ["
        "{ \"id\": \"dl\", \"object\": \"model\", \"created\": 0, \"owned_by\": \"user\" }"
        "] }";
    const std::string response = createJsonResponse(modelList);
    client_socket.write(response.c_str(), response.length());
}

void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer, Sampler *sampler, AppArgs* args, TransformerSpec* spec) {
SocketServer* server = new SocketServer(args->port);
printf("Server URL: http://127.0.0.1:%d/v1/\n", args->port);

std::vector<Route> routes = {
{
"/v1/chat/completions",
HttpMethod::METHOD_POST,
std::bind(&handleCompletionsRequest, std::placeholders::_1, std::placeholders::_2, inference, tokenizer, sampler, args, spec)
},
{
"/v1/models",
HttpMethod::METHOD_GET,
std::bind(&handleModelsRequest, std::placeholders::_1, std::placeholders::_2)
}
};

Expand All @@ -464,7 +478,7 @@ void server(Inference* inference, SocketPool* socketPool, Tokenizer *tokenizer,
// Parse the HTTP request
HttpRequest request = HttpParser::parseRequest(std::string(httpRequest.begin(), httpRequest.end()));
// Handle the HTTP request
printf("New Request: %s %s\n", request.getMethod().c_str(), request.path.c_str());
printf("🔷 %s %s\n", request.getMethod().c_str(), request.path.c_str());
Router::routeRequest(client, request, routes);
} catch (ReadSocketException& ex) {
printf("Read socket error: %d %s\n", ex.code, ex.message);
Expand Down
3 changes: 1 addition & 2 deletions src/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ Socket SocketServer::accept() {
throw std::runtime_error("Error accepting connection");
setNoDelay(clientSocket);
setQuickAck(clientSocket);
printf("Client connected\n");
return Socket(clientSocket);
}

Expand Down Expand Up @@ -261,7 +260,7 @@ bool Socket::tryRead(void* data, size_t size, unsigned long maxAttempts) {

std::vector<char> Socket::readHttpRequest() {
std::vector<char> httpRequest;
char buffer[1024]; // Initial buffer size
char buffer[1024 * 1024]; // TODO: this should be refactored asap
ssize_t bytesRead;

// Peek into the socket buffer to check available data
Expand Down
Loading