diff --git a/.flake8 b/.flake8 index b8f0639..b1142f5 100644 --- a/.flake8 +++ b/.flake8 @@ -1,6 +1,5 @@ [flake8] ignore = E402 -exclude = marimapper/backends/fadecandy/opc.py, venv, marimapper/database.py, marimapper/read_write_model.py +exclude = marimapper/backends/fadecandy/opc.py, venv, marimapper/pycolmap_tools/* max-line-length=127 -max-complexity = 10 -extend-ignore = W503, E203, C901 \ No newline at end of file +extend-ignore = W503, E203 \ No newline at end of file diff --git a/marimapper/backends/fcmega/fcmega_backend.py b/marimapper/backends/fcmega/fcmega_backend.py index 68694a8..8320499 100644 --- a/marimapper/backends/fcmega/fcmega_backend.py +++ b/marimapper/backends/fcmega/fcmega_backend.py @@ -10,12 +10,9 @@ def __init__(self): self.fc_mega = FCMega() self.leds = [(0, 0, 0) for _ in range(self.get_led_count())] self.running = True - self.update_thread = threading.Thread(target=self._run) + self.update_thread = threading.Thread(target=self._run, daemon=True) self.update_thread.start() - def __del__(self): - self.running = False - def get_led_count(self): return 24 * 400 diff --git a/marimapper/camera.py b/marimapper/camera.py index 8180e5c..338c3f9 100644 --- a/marimapper/camera.py +++ b/marimapper/camera.py @@ -36,7 +36,7 @@ def __init__(self, device_id): break if not self.device.isOpened(): - logging.error(f"Failed to connect to camera {device_id}") + raise RuntimeError(f"Failed to connect to camera {device_id}") self.set_resolution(self.get_width(), self.get_height()) # Don't ask @@ -71,7 +71,8 @@ def set_resolution(self, width, height): new_width = self.get_width() new_height = self.get_height() - if width != new_width or height != new_height: + # this is cov ignored as it's a strange position to be in but ultimately fine + if width != new_width or height != new_height: # pragma: no cover logging.error( f"Failed to set camera {self.device_id} resolution to {width} x {height}", ) diff --git a/marimapper/database_populator.py b/marimapper/database_populator.py index b0360b7..ef068f5 100644 --- a/marimapper/database_populator.py +++ b/marimapper/database_populator.py @@ -3,7 +3,7 @@ import numpy as np -from marimapper.database import COLMAPDatabase +from marimapper.pycolmap_tools.database import COLMAPDatabase def populate(db_path, led_maps_2d): diff --git a/marimapper/led_map_3d.py b/marimapper/led_map_3d.py index ec8ef86..b860428 100644 --- a/marimapper/led_map_3d.py +++ b/marimapper/led_map_3d.py @@ -13,13 +13,13 @@ def __init__(self, data=None): if data is not None: self.data = data - def __setitem__(self, led_index, led_data): + def __setitem__(self, led_index, led_data): # pragma: no cover self.data[led_index] = led_data - def __getitem__(self, led_index): + def __getitem__(self, led_index): # pragma: no cover return self.data[led_index] - def __contains__(self, led_index): + def __contains__(self, led_index): # pragma: no cover return led_index in self.data def __len__(self): @@ -64,9 +64,6 @@ def rescale(self, target_inter_distance=1.0): for cam in self.cameras: cam[1] *= scale - def get_normal_list(self): - return np.array([self[led_id]["normal"] for led_id in self.keys()]) - def get_inter_led_distance(self): max_led_id = max(self.keys()) diff --git a/marimapper/map_cleaner.py b/marimapper/map_cleaner.py index dbfda93..71217c2 100644 --- a/marimapper/map_cleaner.py +++ b/marimapper/map_cleaner.py @@ -72,44 +72,3 @@ def fill_gaps(led_map, max_dist_err=0.2, max_missing=5): total_leds_filled += leds_missing return total_leds_filled - - -def extract_strips(led_map, max_dist_err=0.5): - - max_led_id = max(led_map.keys()) - - strips = [[]] - led_id = -1 - - led_to_led_distance = find_inter_led_distance(led_map) - - max_dist = (1 + max_dist_err) * led_to_led_distance - min_dist = (1 - max_dist_err) * led_to_led_distance - - while True: - - led_id += 1 - - if led_id > max_led_id: - break - - if led_id not in led_map: - continue - - if led_id + 1 not in led_map: - if strips[-1]: - strips[-1].append(led_id) - strips.append([]) - continue - - distance = _distance_between_leds(led_map[led_id], led_map[led_id + 1]) - - if not (min_dist < distance < max_dist): - if strips[-1]: - strips[-1].append(led_id) - strips.append([]) - continue - - strips[-1].append(led_id) - - return strips diff --git a/marimapper/model.py b/marimapper/model.py index 69ddfcc..d735505 100644 --- a/marimapper/model.py +++ b/marimapper/model.py @@ -2,7 +2,7 @@ import numpy as np -from marimapper.read_write_model import ( +from marimapper.pycolmap_tools.read_write_model import ( qvec2rotmat, read_images_binary, read_points3D_binary, diff --git a/marimapper/pycolmap_tools/__init__.py b/marimapper/pycolmap_tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/marimapper/database.py b/marimapper/pycolmap_tools/database.py similarity index 99% rename from marimapper/database.py rename to marimapper/pycolmap_tools/database.py index 1e15a6d..1b12781 100644 --- a/marimapper/database.py +++ b/marimapper/pycolmap_tools/database.py @@ -1,4 +1,3 @@ -# fmt: off # Copyright (c) 2023, ETH Zurich and UNC Chapel Hill. # All rights reserved. # @@ -429,5 +428,3 @@ def example_usage(): if __name__ == "__main__": example_usage() - -# fmt: on diff --git a/marimapper/read_write_model.py b/marimapper/pycolmap_tools/read_write_model.py similarity index 99% rename from marimapper/read_write_model.py rename to marimapper/pycolmap_tools/read_write_model.py index eec9b08..037feb2 100644 --- a/marimapper/read_write_model.py +++ b/marimapper/pycolmap_tools/read_write_model.py @@ -1,4 +1,3 @@ -# fmt: off # Copyright (c) 2023, ETH Zurich and UNC Chapel Hill. # All rights reserved. # @@ -604,4 +603,3 @@ def main(): if __name__ == "__main__": main() -# fmt: on diff --git a/marimapper/utils.py b/marimapper/utils.py index 1051139..8f87592 100644 --- a/marimapper/utils.py +++ b/marimapper/utils.py @@ -61,7 +61,7 @@ def add_backend_args(parser): parser.add_argument("--server", type=str, help="Some backends require a server") -def get_user_confirmation(prompt): +def get_user_confirmation(prompt): # pragma: no coverage try: uin = input(logging.colorise(prompt, logging.Col.BLUE)) @@ -86,13 +86,21 @@ def load_custom_backend(backend_file, server=None): backend = custom_backend.Backend(server) if server else custom_backend.Backend() - if "get_led_count" not in dir(backend): - raise RuntimeError("Your backend does not have a get_led_count function") + check_backend(backend) - if "set_led" not in dir(backend): - raise RuntimeError("Your backend does not have a set_led function") + return backend + + +def check_backend(backend): - if len(signature(backend.get_led_count).parameters) != 0: + missing_funcs = {"get_led_count", "set_led"}.difference(set(dir(backend))) + + if missing_funcs: + raise RuntimeError( + f"Your backend does not have the following functions: {missing_funcs}" + ) + + if len(signature(backend.get_led_count).parameters) != 0: # pragma: no coverage raise RuntimeError( "Your backend get_led_count function should not take any arguments" ) @@ -102,8 +110,6 @@ def load_custom_backend(backend_file, server=None): "Your backend set_led function should only take two arguments" ) - return backend - def get_backend(backend_name, server=""): if backend_name == "fadecandy": diff --git a/marimapper/visualize_model.py b/marimapper/visualize_model.py index a6fb19f..d555b88 100644 --- a/marimapper/visualize_model.py +++ b/marimapper/visualize_model.py @@ -101,7 +101,7 @@ def reload_geometry__(self, first=False): np.array([led_map.data[led_id]["pos"] for led_id in led_map.keys()]) ) self.point_cloud.normals = open3d.utility.Vector3dVector( - led_map.get_normal_list() * 0.2 + np.array([led_map[led_id]["normal"] for led_id in led_map.keys()]) * 0.2 ) self.strip_set.points = self.point_cloud.points diff --git a/pyproject.toml b/pyproject.toml index e0fa11f..7242f51 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,10 +35,11 @@ readme = "README.md" license = {file = "LICENSE"} - [project.optional-dependencies] develop = [ "pytest", + "pytest-cov", + "pytest-mock", "black", "flake8", "flake8-bugbear" @@ -62,5 +63,14 @@ omit = [ "*/__main__.py", "marimapper/backends/*", "marimapper/scripts/*", - "marimapper/read_write_model.py" - ] \ No newline at end of file + "marimapper/pycolmap_tools/*" + ] + +[tool.black] +exclude = ''' +(/( + | venv + | marimapper/backends + | marimapper/pycolmap_tools +)/) +''' \ No newline at end of file diff --git a/test/test_backend.py b/test/test_backend.py new file mode 100644 index 0000000..3209f37 --- /dev/null +++ b/test/test_backend.py @@ -0,0 +1,167 @@ +import pytest +import tempfile + +from marimapper.utils import get_backend + + +def test_basic_usage(): + + temp_backend_file = tempfile.NamedTemporaryFile(delete=False, suffix=".py") + temp_backend_file.write( + b""" +class Backend: + + def __init__(self): + pass + + def get_led_count(self): + return 1 + + def set_led(self, led_index, on): + pass +""" + ) + temp_backend_file.close() + backend = get_backend(temp_backend_file.name) + + assert backend.get_led_count() == 1 + + +def test_invalid_backend_due_to_missing_function(): + temp_backend_file = tempfile.NamedTemporaryFile(delete=False, suffix=".py") + temp_backend_file.write( + b""" +class Backend: + + def __init__(self): + pass + + def get_leds(self): # Should be get_led_count() + return 1 + + def set_led(self, led_index, on): + pass + """ + ) + temp_backend_file.close() + + with pytest.raises(RuntimeError): + get_backend(temp_backend_file.name) + + +def test_invalid_backend_due_to_missing_function_arguments(): + temp_backend_file = tempfile.NamedTemporaryFile(delete=False, suffix=".py") + temp_backend_file.write( + b""" +class Backend: + + def __init__(self): + pass + + def get_led_count(self) -> int: + return 1 + + def set_led(self, led_index): # this is missing the on parameter + pass + """ + ) + temp_backend_file.close() + + with pytest.raises(RuntimeError): + get_backend(temp_backend_file.name) + + +def test_fadecandy(monkeypatch): + + from marimapper.backends.fadecandy import opc + + def get_client_patch(uri): + class ClientPatch: + def put_pixels(self, _): + pass + + return ClientPatch() + + monkeypatch.setattr(opc, "Client", get_client_patch) + + get_backend("fadecandy") + get_backend("fadecandy", "1.2.3.4") + + +def test_wled(monkeypatch): + + import requests + + def return_response_patch(*arg, **kwargs): + class ResponsePatch: + status_code = 200 + + def json(self): + return {"leds": {"count": 1}} + + return ResponsePatch() + + monkeypatch.setattr(requests, "post", return_response_patch) + monkeypatch.setattr(requests, "get", return_response_patch) + + get_backend("wled") + get_backend("wled", "1.2.3.4") + + +def test_fcmega(monkeypatch): + + import serial + import serial.tools.list_ports + + class SerialPatch: + + def __init__(self, _): + self.is_open = True + + def write(self, _): + pass + + def read(self, _): + return b"1" + + def comports_patch(): + class ComportPatch: + serial_number = "FCM000" + name = "10" + + return [ComportPatch()] + + monkeypatch.setattr(serial, "Serial", SerialPatch) + monkeypatch.setattr(serial.tools.list_ports, "comports", comports_patch) + + get_backend("fcmega") + + +def test_pixelblaze(monkeypatch): + + # mini_racer uses import pkg_resources which is depreciated + with pytest.warns(DeprecationWarning): + + import pixelblaze + + class PixelblazePatch: + def __init__(self, _): + pass + + def setActivePatternByName(self, _): + pass + + def getPixelCount(self): + return 1 + + monkeypatch.setattr(pixelblaze, "Pixelblaze", PixelblazePatch) + + get_backend("pixelblaze", "1.2.3.4") + + +def test_invalid_or_none_backend(): + + assert get_backend("None") is None + + with pytest.raises(RuntimeError): + get_backend("invalid_backend") diff --git a/test/test_camera.py b/test/test_camera.py new file mode 100644 index 0000000..f60cf74 --- /dev/null +++ b/test/test_camera.py @@ -0,0 +1,21 @@ +import pytest +from marimapper.camera import Camera + + +def test_valid_device(): + + cam = Camera("test/MariMapper-Test-Data/9_point_box/cam_0/capture_0000.png") + + image = cam.read() + + assert image.shape == (480, 640) # Grey + + image_bw = cam.read(color=True) + + assert image_bw.shape == (480, 640, 4) # RGBA + + +def test_invalid_device(): + + with pytest.raises(RuntimeError): + Camera(device_id="bananas") diff --git a/test/test_custom_backend.py b/test/test_custom_backend.py deleted file mode 100644 index fee53f7..0000000 --- a/test/test_custom_backend.py +++ /dev/null @@ -1,49 +0,0 @@ -import pytest -import tempfile - -from marimapper.utils import load_custom_backend - - -def test_basic_usage(): - - temp_backend_file = tempfile.NamedTemporaryFile(delete=False, suffix=".py") - temp_backend_file.write( - b""" -class Backend: - - def __init__(self): - pass - - def get_led_count(self) -> int: - return 1 - - def set_led(self, led_index: int, on: bool) -> None: - pass -""" - ) - temp_backend_file.close() - backend = load_custom_backend(temp_backend_file.name) - - assert backend.get_led_count() == 1 - - -def test_invalid_backend(): - temp_backend_file = tempfile.NamedTemporaryFile(delete=False, suffix=".py") - temp_backend_file.write( - b""" -class Backend: - - def __init__(self): - pass - - def get_leds(self) -> int: # Should be get_led_count() - return 1 - - def set_led(self, led_index: int, on: bool) -> None: - pass - """ - ) - temp_backend_file.close() - - with pytest.raises(RuntimeError): - load_custom_backend(temp_backend_file.name) diff --git a/test/test_reconstruction.py b/test/test_reconstruction.py index c8de6dc..0598f0a 100644 --- a/test/test_reconstruction.py +++ b/test/test_reconstruction.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from marimapper.sfm import SFM from marimapper.led_map_2d import get_all_2d_led_maps @@ -87,3 +88,34 @@ def test_reconstruct_higbeam(): map_3d = SFM.process__(highbeam_map) assert map_3d is not None + + +# this test does a re-scale, but should keep the dimensions about the same +def test_rescale(): + maps = get_all_2d_led_maps("test/scan") + + map_3d = SFM.process__(maps, rescale=True) + + assert map_3d.get_inter_led_distance() == pytest.approx(1.0) + + +def test_connected(): + + maps = get_all_2d_led_maps("test/scan") + + map_3d = SFM.process__(maps) + + assert len(map_3d) == 21 + + connected = map_3d.get_connected_leds() + assert (6, 7) not in connected + assert (13, 14) not in connected + + +def test_interpolate(): + + maps = get_all_2d_led_maps("test/scan") + + map_3d = SFM.process__(maps, interpolate=True) + + assert len(map_3d) == 23