From 3a23584a1e2f07da7c0fc9694e02fd92ff1b94a4 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 30 Dec 2023 22:02:09 +1100 Subject: [PATCH 01/16] [server] Rename Blueprint to SubServer and create Blueprint alternative --- docs/blueprints.md | 54 --------- docs/error-handling.md | 5 +- docs/middleware.md | 7 +- docs/subserver-blueprint.md | 98 +++++++++++++++ mkdocs.yml | 2 +- pyproject.toml | 2 +- src/nserver/__init__.py | 2 +- src/nserver/server.py | 232 +++++++++++++++++++++--------------- tests/test_blueprint.py | 132 +------------------- tests/test_subserver.py | 186 +++++++++++++++++++++++++++++ 10 files changed, 440 insertions(+), 280 deletions(-) delete mode 100644 docs/blueprints.md create mode 100644 docs/subserver-blueprint.md create mode 100644 tests/test_subserver.py diff --git a/docs/blueprints.md b/docs/blueprints.md deleted file mode 100644 index 90462ec..0000000 --- a/docs/blueprints.md +++ /dev/null @@ -1,54 +0,0 @@ -# Blueprints - -[`Blueprint`][nserver.server.Blueprint]s provide a way for you to compose your application. They support most of the same functionality as a `NameServer`. - -Use cases: - -- Split up your application across different blueprints for maintainability / composability. -- Reuse a blueprint registered under different rules. -- Allow custom packages to define their own rules that you can add to your own server. - -Blueprints require `nserver>=2.0` - -## Using Blueprints - -```python -from nserver import Blueprint, NameServer, ZoneRule, ALL_CTYPES, A - -# First Blueprint -mysite = Blueprint("mysite") - -@mysite.rule("nicholashairs.com", ["A"]) -@mysite.rule("www.nicholashairs.com", ["A"]) -def nicholashairs_website(query: Query) -> A: - return A(query.name, "159.65.13.73") - -@mysite.rule(ZoneRule, "", ALL_CTYPES) -def nicholashairs_catchall(query: Query) -> None: - # Return empty response for all other queries - return None - -# Second Blueprint -en_blueprint = Blueprint("english-speaking-blueprint") - -@en_blueprint.rule("hello.{base_domain}", ["A"]) -def en_hello(query: Query) -> A: - return A(query.name, "1.1.1.1") - -# Register to NameServer -server = NameServer("server") -server.register_blueprint(mysite, ZoneRule, "nicholashairs.com", ALL_CTYPES) -server.register_blueprint(en_blueprint, ZoneRule, "au", ALL_CTYPES) -server.register_blueprint(en_blueprint, ZoneRule, "nz", ALL_CTYPES) -server.register_blueprint(en_blueprint, ZoneRule, "uk", ALL_CTYPES) -``` - -### Middleware, Hooks, and Error Handling - -Blueprints maintain their own `QueryMiddleware` stack which will run before any rule function is run. Included in this stack is the `HookMiddleware` and `ExceptionHandlerMiddleware`. - -## Key differences with `NameServer` - -- Does not use settings (`Setting`). -- Does not have a `Transport`. -- Does not have a `RawRecordMiddleware` stack. diff --git a/docs/error-handling.md b/docs/error-handling.md index a27c2e5..86243dc 100644 --- a/docs/error-handling.md +++ b/docs/error-handling.md @@ -2,7 +2,8 @@ Custom exception handling is handled through the [`ExceptionHandlerMiddleware`][nserver.middleware.ExceptionHandlerMiddleware] and [`RawRecordExceptionHandlerMiddleware`][nserver.middleware.RawRecordExceptionHandlerMiddleware] [Middleware][middleware]. These middleware will catch any `Exception`s raised by their respective middleware stacks. -Error handling requires `nserver>=2.0` +!!! note + Error handling requires `nserver>=2.0` In general you are probably able to use the `ExceptionHandlerMiddleware` as the `RawRecordExceptionHandlerMiddleware` is only needed to catch exceptions resulting from `RawRecordMiddleware` or broken exception handlers in the `ExceptionHandlerMiddleware`. If you only write `QueryMiddleware` and your `ExceptionHandlerMiddleware` handlers never raise exceptions then you'll be good to go with just the `ExceptionHandlerMiddleware`. @@ -15,6 +16,8 @@ Handlers are chosen by finding a handler for the most specific parent class of t ## Registering Exception Handlers +Exception handlers can be registered to `NameServer` and `SubServer` instances using either their `@[raw_]exception_handler` decorators or their `register_[raw_]exception_handler` methods. + ```python import dnslib from nserver import NameServer, Query, Response diff --git a/docs/middleware.md b/docs/middleware.md index c54fc85..41987ff 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -2,7 +2,8 @@ Middleware can be used to modify the behaviour of a server seperate to the individual rules that are registered to the server. Middleware is run on all requests and can modify both the input and response of a request. -Middleware requires `nserver>=2.0` +!!! note + Middleware requires `nserver>=2.0` ## Middleware Stacks @@ -18,6 +19,8 @@ For most use cases you likely want to use [`QueryMiddleware`][nserver.middleware ### Registering `QueryMiddleware` +`QueryMiddleware` can be registered to `NameServer` and `SubServer` instances using their `register_middleware` methods. + ```python from nserver import NameServer from nserver.middleware import QueryMiddleware @@ -72,6 +75,8 @@ Once processed the `QueryMiddleware` stack will look as follows: ### Registering `RawRecordMiddleware` +`RawRecordMiddleware` can be registered to `NameServer` instances using their `register_raw_middleware` method. + ```python # ... from nserver.middleware import RawRecordMiddleware diff --git a/docs/subserver-blueprint.md b/docs/subserver-blueprint.md new file mode 100644 index 0000000..c880566 --- /dev/null +++ b/docs/subserver-blueprint.md @@ -0,0 +1,98 @@ +# Sub-Servers and Blueprints + + +## Sub-Servers + +[`SubServer`][nserver.server.SubServer] provides a way for you to compose your application. They support most of the same functionality as a `NameServer`. + +Use cases: + +- Split up your application across different servers for maintainability / composability. +- Reuse a server registered under different rules. +- Allow custom packages to define their own rules that you can add to your own server. + +!!! note + SubServers requires `nserver>=2.0` + +### Using Sub-Servers + +```python +from nserver import SubServer, NameServer, ZoneRule, ALL_CTYPES, A, TXT + +# First SubServer +mysite = SubServer("mysite") + +@mysite.rule("nicholashairs.com", ["A"]) +@mysite.rule("www.nicholashairs.com", ["A"]) +def nicholashairs_website(query: Query) -> A: + return A(query.name, "159.65.13.73") + +@mysite.rule(ZoneRule, "", ALL_CTYPES) +def nicholashairs_catchall(query: Query) -> None: + # Return empty response for all other queries + return None + +# Second SubServer +en_subserver = SubServer("english-speaking-blueprint") + +@en_subserver.rule("hello.{base_domain}", ["TXT"]) +def en_hello(query: Query) -> TXT: + return TXT(query.name, "Hello There!") + +# Register to NameServer +server = NameServer("server") +server.register_subserver(mysite, ZoneRule, "nicholashairs.com", ALL_CTYPES) +server.register_subserver(en_subserver, ZoneRule, "au", ALL_CTYPES) +server.register_subserver(en_subserver, ZoneRule, "nz", ALL_CTYPES) +server.register_subserver(en_subserver, ZoneRule, "uk", ALL_CTYPES) +``` + +#### Middleware, Hooks, and Error Handling + +Sub-Servers maintain their own `QueryMiddleware` stack which will run before any rule function is run. Included in this stack is the `HookMiddleware` and `ExceptionHandlerMiddleware`. + +### Key differences with `NameServer` + +- Does not use settings (`Setting`). +- Does not have a `Transport`. +- Does not have a `RawRecordMiddleware` stack. + +## Blueprints + +[`Blueprint`][nserver.server.Blueprint]s act as a container for rules. They are an efficient way to compose your application if you do not want or need to use functionality provided by a `QueryMiddleware` stack. + +!!! note + Blueprints require `nserver>=2.0` + +### Using Blueprints + +```python +# ... +from nserver import Blueprint, MX + +no_email_blueprint = Blueprint("noemail") + +@no_email_blueprint.rule("{base_domain}", ["MX"]) +@no_email_blueprint.rule("**.{base_domain}", ["MX"]) +def no_email(query: Query) -> MX: + "Indicate that we do not have a mail exchange" + return MX(query.name, ".", 0) + + +## Add it to our sub-servers +en_subserver.register_rule(no_email_blueprint) + +# Problem! Because we have already registered the nicholashairs_catchall rule, +# it will prevent our blueprint from being called. So instead let's manually +# insert it as the first rule. +mysite.rules.insert(0, no_email_blueprint) +``` + +### Key differences with `NameServer` and `SubServer` + +- Only provides the `@rule` decorator and `register_rule` method. + - It does not have a `QueryMiddleware` stack which means it does not support hooks or error-handling. +- Is used directly in `register_rule` (e.g. `some_server.register_rule(my_blueprint)`). +- If rule does not match an internal rule will continue to the next rule in the parent server. + + In comparison the server classes will return `NXDOMAIN` if a rule doesn't match their internal rules. diff --git a/mkdocs.yml b/mkdocs.yml index fa88c1d..2524ccf 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -13,7 +13,7 @@ nav: - quickstart.md - middleware.md - error-handling.md - - blueprints.md + - subserver-blueprint.md - production-deployment.md - changelog.md - external-resources.md diff --git a/pyproject.toml b/pyproject.toml index 7c9a83b..9ff5a2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "nserver" -version = "2.0.0" +version = "2.0.1.rc1" description = "DNS Name Server Framework" authors = [ {name = "Nicholas Hairs", email = "info+nserver@nicholashairs.com"}, diff --git a/src/nserver/__init__.py b/src/nserver/__init__.py index 54e8c75..d4fef35 100644 --- a/src/nserver/__init__.py +++ b/src/nserver/__init__.py @@ -1,5 +1,5 @@ from .models import Query, Response from .rules import ALL_QTYPES, StaticRule, ZoneRule, RegexRule, WildcardStringRule from .records import A, AAAA, NS, CNAME, PTR, SOA, MX, TXT, CAA -from .server import NameServer, Blueprint +from .server import NameServer, SubServer, Blueprint from .settings import Settings diff --git a/src/nserver/server.py b/src/nserver/server.py index 8c69b61..0482914 100644 --- a/src/nserver/server.py +++ b/src/nserver/server.py @@ -29,27 +29,114 @@ ### Classes ### ============================================================================ -class Scaffold: - """Base class for shared functionality between `NameServer` and `Blueprint` +class _LoggingMixin: # pylint: disable=too-few-public-methods + """Self bound logging methods""" + + _logger: logging.Logger + + def _vvdebug(self, *args, **kwargs): + """Log very verbose debug message.""" + + return self._logger.log(6, *args, **kwargs) + + def _vdebug(self, *args, **kwargs): + """Log verbose debug message.""" + + return self._logger.log(8, *args, **kwargs) + + def _debug(self, *args, **kwargs): + """Log debug message.""" + + return self._logger.debug(*args, **kwargs) + + def _info(self, *args, **kwargs): + """Log very verbose debug message.""" + + return self._logger.info(*args, **kwargs) + + def _warning(self, *args, **kwargs): + """Log warning message.""" + + return self._logger.warning(*args, **kwargs) + + def _error(self, *args, **kwargs): + """Log an error message.""" + + return self._logger.error(*args, **kwargs) + + def _critical(self, *args, **kwargs): + """Log a critical message.""" + + return self._logger.critical(*args, **kwargs) + + +class RulesContainer(_LoggingMixin): + """Base class for rules based functionality` New in `2.0`. Attributes: rules: registered rules + """ + + def __init__(self) -> None: + super().__init__() + self.rules: List[RuleBase] = [] + return + + def register_rule(self, rule: RuleBase) -> None: + """Register the given rule + + Args: + rule: the rule to register + """ + self._debug(f"Registered rule: {rule!r}") + self.rules.append(rule) + return + + def rule(self, rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs): + """Decorator for registering a function using [`smart_make_rule`][nserver.rules.smart_make_rule]. + + Changed in `2.0`: This method now uses `smart_make_rule`. + + Args: + rule_: rule as per `nserver.rules.smart_make_rule` + args: extra arguments to provide `smart_make_rule` + kwargs: extra keyword arguments to provide `smart_make_rule` + + Raises: + ValueError: if `func` is provided in `kwargs`. + """ + + if "func" in kwargs: + raise ValueError("Must not provide `func` in kwargs") + + def decorator(func: ResponseFunction): + nonlocal rule_ + nonlocal args + nonlocal kwargs + self.register_rule(smart_make_rule(rule_, *args, func=func, **kwargs)) + return func + + return decorator + + +class ServerBase(RulesContainer): + """Base class for shared functionality between `NameServer` and `SubServer` + + New in `2.0`. + + Attributes: hook_middleware: hook middleware exception_handler_middleware: Query exception handler middleware """ - _logger: logging.Logger - - def __init__(self, name: str) -> None: + def __init__(self) -> None: """ Args: name: The name of the server. This is used for internal logging. """ - self.name = name - - self.rules: List[RuleBase] = [] + super().__init__() self.hook_middleware = middleware.HookMiddleware() self.exception_handler_middleware = middleware.ExceptionHandlerMiddleware() @@ -61,25 +148,15 @@ def __init__(self, name: str) -> None: ## Register Methods ## ------------------------------------------------------------------------- - def register_rule(self, rule: RuleBase) -> None: - """Register the given rule - - Args: - rule: the rule to register - """ - self._debug(f"Registered rule: {rule!r}") - self.rules.append(rule) - return - - def register_blueprint( - self, blueprint: "Blueprint", rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs + def register_subserver( + self, subserver: "SubServer", rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs ) -> None: - """Register a blueprint using [`smart_make_rule`][nserver.rules.smart_make_rule]. + """Register a `SubServer` using [`smart_make_rule`][nserver.rules.smart_make_rule]. New in `2.0`. Args: - blueprint: the `Blueprint` to attach + subserver: the `SubServer` to attach rule_: rule as per `nserver.rules.smart_make_rule` args: extra arguments to provide `smart_make_rule` kwargs: extra keyword arguments to provide `smart_make_rule` @@ -90,7 +167,7 @@ def register_blueprint( if "func" in kwargs: raise ValueError("Must not provide `func` in kwargs") - self.register_rule(smart_make_rule(rule_, *args, func=blueprint.entrypoint, **kwargs)) + self.register_rule(smart_make_rule(rule_, *args, func=subserver.entrypoint, **kwargs)) return def register_before_first_query(self, func: middleware.BeforeFirstQueryHook) -> None: @@ -158,32 +235,6 @@ def register_exception_handler( # Decorators # .......................................................................... - def rule(self, rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs): - """Decorator for registering a function using [`smart_make_rule`][nserver.rules.smart_make_rule]. - - Changed in `2.0`: This method now uses `smart_make_rule`. - - Args: - rule_: rule as per `nserver.rules.smart_make_rule` - args: extra arguments to provide `smart_make_rule` - kwargs: extra keyword arguments to provide `smart_make_rule` - - Raises: - ValueError: if `func` is provided in `kwargs`. - """ - - if "func" in kwargs: - raise ValueError("Must not provide `func` in kwargs") - - def decorator(func: ResponseFunction): - nonlocal rule_ - nonlocal args - nonlocal kwargs - self.register_rule(smart_make_rule(rule_, *args, func=func, **kwargs)) - return func - - return decorator - def before_first_query(self): """Decorator for registering before_first_query hook. @@ -266,45 +317,8 @@ def _prepare_query_middleware_stack(self) -> None: self._query_middleware_stack.append(rule_processor) return - ## Logging - ## ------------------------------------------------------------------------- - def _vvdebug(self, *args, **kwargs): - """Log very verbose debug message.""" - - return self._logger.log(6, *args, **kwargs) - - def _vdebug(self, *args, **kwargs): - """Log verbose debug message.""" - - return self._logger.log(8, *args, **kwargs) - - def _debug(self, *args, **kwargs): - """Log debug message.""" - - return self._logger.debug(*args, **kwargs) - def _info(self, *args, **kwargs): - """Log very verbose debug message.""" - - return self._logger.info(*args, **kwargs) - - def _warning(self, *args, **kwargs): - """Log warning message.""" - - return self._logger.warning(*args, **kwargs) - - def _error(self, *args, **kwargs): - """Log an error message.""" - - return self._logger.error(*args, **kwargs) - - def _critical(self, *args, **kwargs): - """Log a critical message.""" - - return self._logger.critical(*args, **kwargs) - - -class NameServer(Scaffold): +class NameServer(ServerBase): """NameServer for responding to requests.""" # pylint: disable=too-many-instance-attributes @@ -315,8 +329,9 @@ def __init__(self, name: str, settings: Optional[Settings] = None) -> None: name: The name of the server. This is used for internal logging. settings: settings to use with this `NameServer` instance """ - super().__init__(name) - self._logger = logging.getLogger(f"nserver.i.{self.name}") + super().__init__() + self.name = name + self._logger = logging.getLogger(f"nserver.i.nameserver.{self.name}") self.raw_exception_handler_middleware = middleware.RawRecordExceptionHandlerMiddleware() self._user_raw_record_middleware: List[middleware.RawRecordMiddleware] = [] @@ -514,11 +529,13 @@ def _prepare_raw_record_middleware_stack(self) -> None: return -class Blueprint(Scaffold): +class SubServer(ServerBase): """Class that can replicate many of the functions of a `NameServer`. They can be used to construct or extend applications. + A `SubServer` maintains it's own `QueryMiddleware` stack and list of rules. + New in `2.0`. """ @@ -527,15 +544,44 @@ def __init__(self, name: str) -> None: Args: name: The name of the server. This is used for internal logging. """ - super().__init__(name) - self._logger = logging.getLogger(f"nserver.b.{self.name}") + super().__init__() + self.name = name + self._logger = logging.getLogger(f"nserver.i.subserver.{self.name}") return def entrypoint(self, query: Query) -> Response: - """Entrypoint into this `Blueprint`. + """Entrypoint into this `SubServer`. This method should be passed to rules as the function to run. """ if not self._query_middleware_stack: self._prepare_query_middleware_stack() return self._query_middleware_stack[0](query) + + +class Blueprint(RulesContainer, RuleBase): + """A container for rules that can be registered onto a server + + It can be registered as normal rule: `server.register_rule(blueprint_rule)` + + New in `2.0`. + """ + + def __init__(self, name: str) -> None: + """ + Args: + name: The name of the server. This is used for internal logging. + """ + super().__init__() + self.name = name + self._logger = logging.getLogger(f"nserver.i.blueprint.{self.name}") + return + + def get_func(self, query: Query) -> Optional[ResponseFunction]: + for rule in self.rules: + func = rule.get_func(query) + if func is not None: + self._debug(f"matched {rule}") + return func + self._debug("did not match any rule") + return None diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py index e801dc8..ed9df03 100644 --- a/tests/test_blueprint.py +++ b/tests/test_blueprint.py @@ -3,15 +3,11 @@ ### IMPORTS ### ============================================================================ ## Standard Library -from typing import no_type_check, List -import unittest.mock - ## Installed import dnslib import pytest -from nserver import NameServer, Blueprint, Query, Response, ALL_QTYPES, ZoneRule, A -from nserver.server import Scaffold +from nserver import NameServer, Blueprint, Query, A ## Application @@ -34,98 +30,11 @@ def dummy_rule(query: Query) -> A: return A(query.name, IP) -## Hooks -## ----------------------------------------------------------------------------- -def register_hooks(scaff: Scaffold) -> None: - scaff.register_before_first_query(unittest.mock.MagicMock(wraps=lambda: None)) - scaff.register_before_query(unittest.mock.MagicMock(wraps=lambda q: None)) - scaff.register_after_query(unittest.mock.MagicMock(wraps=lambda r: r)) - return - - -@no_type_check -def reset_hooks(scaff: Scaffold) -> None: - scaff.hook_middleware.before_first_query_run = False - scaff.hook_middleware.before_first_query[0].reset_mock() - scaff.hook_middleware.before_query[0].reset_mock() - scaff.hook_middleware.after_query[0].reset_mock() - return - - -def reset_all_hooks() -> None: - reset_hooks(server) - reset_hooks(blueprint_1) - reset_hooks(blueprint_2) - reset_hooks(blueprint_3) - return - - -@no_type_check -def check_hook_call_count(scaff: Scaffold, bfq_count: int, bq_count: int, aq_count: int) -> None: - assert scaff.hook_middleware.before_first_query[0].call_count == bfq_count - assert scaff.hook_middleware.before_query[0].call_count == bq_count - assert scaff.hook_middleware.after_query[0].call_count == aq_count - return - - -register_hooks(server) -register_hooks(blueprint_1) -register_hooks(blueprint_2) -register_hooks(blueprint_3) - - -## Exception handling -## ----------------------------------------------------------------------------- -class ErrorForTesting(Exception): - pass - - -@server.rule("throw-error.com", ["A"]) -def throw_error(query: Query) -> None: - raise ErrorForTesting() - - -def _query_error_handler(query: Query, exception: Exception) -> Response: - # pylint: disable=unused-argument - return Response(error_code=dnslib.RCODE.SERVFAIL) - - -query_error_handler = unittest.mock.MagicMock(wraps=_query_error_handler) -server.register_exception_handler(ErrorForTesting, query_error_handler) - - -class ThrowAnotherError(Exception): - pass - - -@server.rule("throw-another-error.com", ["A"]) -def throw_another_error(query: Query) -> None: - raise ThrowAnotherError() - - -def bad_error_handler(query: Query, exception: Exception) -> Response: - # pylint: disable=unused-argument - raise ErrorForTesting() - - -server.register_exception_handler(ThrowAnotherError, bad_error_handler) - - -def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> dnslib.DNSRecord: - # pylint: disable=unused-argument - response = record.reply() - response.header.rcode = dnslib.RCODE.SERVFAIL - return response - - -raw_record_error_handler = unittest.mock.MagicMock(wraps=_raw_record_error_handler) -server.register_raw_exception_handler(ErrorForTesting, raw_record_error_handler) - ## Get server ready ## ----------------------------------------------------------------------------- -server.register_blueprint(blueprint_1, ZoneRule, "b1.com", ALL_QTYPES) -server.register_blueprint(blueprint_2, ZoneRule, "b2.com", ALL_QTYPES) -blueprint_2.register_blueprint(blueprint_3, ZoneRule, "b3.b2.com", ALL_QTYPES) +server.register_rule(blueprint_1) +server.register_rule(blueprint_2) +blueprint_2.register_rule(blueprint_3) server._prepare_middleware_stacks() @@ -149,36 +58,3 @@ def test_nxdomain(question: str): assert len(response.rr) == 0 assert response.header.rcode == dnslib.RCODE.NXDOMAIN return - - -## Hooks -## ----------------------------------------------------------------------------- -@pytest.mark.parametrize( - "question,hook_counts", - [ - ("s.com", [1, 5, 5]), - ("b1.com", [1, 5, 5, 1, 5, 5]), - ("b2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5]), - ("b3.b2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5, 1, 5, 5]), - ], -) -def test_hooks(question: str, hook_counts: List[int]): - ## Setup - # fill unset hook_counts - hook_counts += [0] * (12 - len(hook_counts)) - assert len(hook_counts) == 12 - # reset hooks - reset_all_hooks() - - ## Test - for _ in range(5): - response = server._process_dns_record(dnslib.DNSRecord.question(question)) - assert len(response.rr) == 1 - assert response.rr[0].rtype == 1 - assert response.rr[0].rname == question - - check_hook_call_count(server, *hook_counts[:3]) - check_hook_call_count(blueprint_1, *hook_counts[3:6]) - check_hook_call_count(blueprint_2, *hook_counts[6:9]) - check_hook_call_count(blueprint_3, *hook_counts[9:]) - return diff --git a/tests/test_subserver.py b/tests/test_subserver.py new file mode 100644 index 0000000..601e51e --- /dev/null +++ b/tests/test_subserver.py @@ -0,0 +1,186 @@ +# pylint: disable=missing-class-docstring,missing-function-docstring,protected-access + +### IMPORTS +### ============================================================================ +## Standard Library +from typing import no_type_check, List +import unittest.mock + +## Installed +import dnslib +import pytest + +from nserver import NameServer, SubServer, Query, Response, ALL_QTYPES, ZoneRule, A +from nserver.server import ServerBase + +## Application + +### SETUP +### ============================================================================ +IP = "127.0.0.1" +nameserver = NameServer("test_subserver") +subserver_1 = SubServer("subserver_1") +subserver_2 = SubServer("subserver_2") +subserver_3 = SubServer("subserver_3") + + +## Rules +## ----------------------------------------------------------------------------- +@nameserver.rule("s.com", ["A"]) +@subserver_1.rule("sub1.com", ["A"]) +@subserver_2.rule("sub2.com", ["A"]) +@subserver_3.rule("sub3.sub2.com", ["A"]) +def dummy_rule(query: Query) -> A: + return A(query.name, IP) + + +## Hooks +## ----------------------------------------------------------------------------- +def register_hooks(server: ServerBase) -> None: + server.register_before_first_query(unittest.mock.MagicMock(wraps=lambda: None)) + server.register_before_query(unittest.mock.MagicMock(wraps=lambda q: None)) + server.register_after_query(unittest.mock.MagicMock(wraps=lambda r: r)) + return + + +@no_type_check +def reset_hooks(server: ServerBase) -> None: + server.hook_middleware.before_first_query_run = False + server.hook_middleware.before_first_query[0].reset_mock() + server.hook_middleware.before_query[0].reset_mock() + server.hook_middleware.after_query[0].reset_mock() + return + + +def reset_all_hooks() -> None: + reset_hooks(nameserver) + reset_hooks(subserver_1) + reset_hooks(subserver_2) + reset_hooks(subserver_3) + return + + +@no_type_check +def check_hook_call_count(server: ServerBase, bfq_count: int, bq_count: int, aq_count: int) -> None: + assert server.hook_middleware.before_first_query[0].call_count == bfq_count + assert server.hook_middleware.before_query[0].call_count == bq_count + assert server.hook_middleware.after_query[0].call_count == aq_count + return + + +register_hooks(nameserver) +register_hooks(subserver_1) +register_hooks(subserver_2) +register_hooks(subserver_3) + + +## Exception handling +## ----------------------------------------------------------------------------- +class ErrorForTesting(Exception): + pass + + +@nameserver.rule("throw-error.com", ["A"]) +def throw_error(query: Query) -> None: + raise ErrorForTesting() + + +def _query_error_handler(query: Query, exception: Exception) -> Response: + # pylint: disable=unused-argument + return Response(error_code=dnslib.RCODE.SERVFAIL) + + +query_error_handler = unittest.mock.MagicMock(wraps=_query_error_handler) +nameserver.register_exception_handler(ErrorForTesting, query_error_handler) + + +class ThrowAnotherError(Exception): + pass + + +@nameserver.rule("throw-another-error.com", ["A"]) +def throw_another_error(query: Query) -> None: + raise ThrowAnotherError() + + +def bad_error_handler(query: Query, exception: Exception) -> Response: + # pylint: disable=unused-argument + raise ErrorForTesting() + + +nameserver.register_exception_handler(ThrowAnotherError, bad_error_handler) + + +def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> dnslib.DNSRecord: + # pylint: disable=unused-argument + response = record.reply() + response.header.rcode = dnslib.RCODE.SERVFAIL + return response + + +raw_record_error_handler = unittest.mock.MagicMock(wraps=_raw_record_error_handler) +nameserver.register_raw_exception_handler(ErrorForTesting, raw_record_error_handler) + +## Get server ready +## ----------------------------------------------------------------------------- +nameserver.register_subserver(subserver_1, ZoneRule, "sub1.com", ALL_QTYPES) +nameserver.register_subserver(subserver_2, ZoneRule, "sub2.com", ALL_QTYPES) +subserver_2.register_subserver(subserver_3, ZoneRule, "sub3.sub2.com", ALL_QTYPES) + +nameserver._prepare_middleware_stacks() + + +### TESTS +### ============================================================================ +## Responses +## ----------------------------------------------------------------------------- +@pytest.mark.parametrize("question", ["s.com", "sub1.com", "sub2.com", "sub3.sub2.com"]) +def test_response(question: str): + response = nameserver._process_dns_record(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 1 + assert response.rr[0].rtype == 1 + assert response.rr[0].rname == question + return + + +@pytest.mark.parametrize( + "question", ["miss.s.com", "miss.sub1.com", "miss.sub2.com", "miss.sub3.sub2.com"] +) +def test_nxdomain(question: str): + response = nameserver._process_dns_record(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 0 + assert response.header.rcode == dnslib.RCODE.NXDOMAIN + return + + +## Hooks +## ----------------------------------------------------------------------------- +@pytest.mark.parametrize( + "question,hook_counts", + [ + ("s.com", [1, 5, 5]), + ("sub1.com", [1, 5, 5, 1, 5, 5]), + ("sub2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5]), + ("sub3.sub2.com", [1, 5, 5, 0, 0, 0, 1, 5, 5, 1, 5, 5]), + ], +) +def test_hooks(question: str, hook_counts: List[int]): + ## Setup + # fill unset hook_counts + hook_counts += [0] * (12 - len(hook_counts)) + assert len(hook_counts) == 12 + # reset hooks + reset_all_hooks() + + ## Test + for _ in range(5): + response = nameserver._process_dns_record(dnslib.DNSRecord.question(question)) + assert len(response.rr) == 1 + assert response.rr[0].rtype == 1 + assert response.rr[0].rname == question + + check_hook_call_count(nameserver, *hook_counts[:3]) + check_hook_call_count(subserver_1, *hook_counts[3:6]) + check_hook_call_count(subserver_2, *hook_counts[6:9]) + check_hook_call_count(subserver_3, *hook_counts[9:]) + return From afe50807dfc5b55d2268919d867d7c777a11661c Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 11:44:50 +1100 Subject: [PATCH 02/16] Drop python 3.7 support, move to uv, add GHA --- .dockerignore | 1 - .github/workflows/test-suite.yml | 54 +++++++++ dev.sh | 200 ++++--------------------------- docker-compose.yml | 42 ------- lib/python/build.Dockerfile | 24 ---- lib/python/build.sh | 19 +-- lib/python/common.Dockerfile | 26 ---- lib/python/install_pypy.sh | 39 ------ lib/python/tox.Dockerfile | 50 -------- pyproject.toml | 31 +++-- src/nserver/_version.py | 1 + tox.ini | 38 +++++- 12 files changed, 124 insertions(+), 401 deletions(-) delete mode 100644 .dockerignore create mode 100644 .github/workflows/test-suite.yml delete mode 100644 docker-compose.yml delete mode 100644 lib/python/build.Dockerfile delete mode 100644 lib/python/common.Dockerfile delete mode 100755 lib/python/install_pypy.sh delete mode 100644 lib/python/tox.Dockerfile diff --git a/.dockerignore b/.dockerignore deleted file mode 100644 index 2d2ecd6..0000000 --- a/.dockerignore +++ /dev/null @@ -1 +0,0 @@ -.git/ diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml new file mode 100644 index 0000000..c7508e0 --- /dev/null +++ b/.github/workflows/test-suite.yml @@ -0,0 +1,54 @@ +name: Test NServer + +on: + push: + branches: + - main + + pull_request: + branches: + - main + +jobs: + lint: + name: "Python Lint" + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: astral-sh/setup-uv@v3 + + - name: Lint with tox + run: uvx tox --with tox-uv -e lint + + test: + name: "Python Test ${{matrix.python-version}} ${{ matrix.os }}" + needs: [lint] + runs-on: "${{ matrix.os }}" + strategy: + fail-fast: false # allow tests to run on all platforms + matrix: + python-version: + - "pypy-3.8" + - "pypy-3.9" + - "pypy-3.10" + - "3.8" + - "3.9" + - "3.10" + - "3.11" + - "3.12" + - "3.13" + os: + - ubuntu-latest + - windows-latest + - macos-latest + + steps: + - uses: actions/checkout@v4 + - uses: astral-sh/setup-uv@v3 + + - name: Set up Python ${{ matrix.python-version }} + run: uv python install ${{ matrix.python-version }} + + - name: Test with tox + run: uvx --with tox-uv tox-gh-actions tox diff --git a/dev.sh b/dev.sh index 0eb753f..d95f317 100755 --- a/dev.sh +++ b/dev.sh @@ -15,13 +15,15 @@ set -e # Bail at the first sign of trouble # Notation Reference: https://unix.stackexchange.com/questions/122845/using-a-b-for-variable-assignment-in-scripts#comment685330_122848 : ${DEBUG:=0} : ${CI:=0} # Flag for if we are in CI - default to not. -: ${SKIP_BUILD:=0} # Allow some commands to forcibly skip compose-build -: ${PORT:=8000} # allows for some commands to change the port if ! command -v toml &> /dev/null; then pip install --user toml-cli fi +if ! command -v uv &> /dev/null; then + pip install --user uv +fi + ### CONTANTS ### ============================================================================ SOURCE_UID=$(id -u) @@ -45,8 +47,7 @@ PACKAGE_VERSION=$(toml get --toml-path pyproject.toml project.version) # You may want to customise these for your project # TODO: this potentially should be moved to manifest.env so that projects can easily # customise the main dev.sh -SOURCE_FILES="src tests" -PYTHON_MIN_VERSION="py37" +PYTHON_MIN_VERSION="py38" ## Build related ## ----------------------------------------------------------------------------- @@ -117,77 +118,6 @@ cp .tmp/env .env ### FUNCTIONS ### ============================================================================ -## Docker Functions -## ----------------------------------------------------------------------------- -function compose_build { - heading2 "🐋 Building $1" - if [[ "$CI" = 1 ]]; then - docker compose build --progress plain $1 - - elif [[ "$DEBUG" -gt 0 ]]; then - docker compose build --progress plain $1 - - else - docker compose build $1 - fi - echo -} - -function compose_run { - heading2 "🐋 running $@" - docker compose -f docker-compose.yml run --rm "$@" - echo -} - -function docker_clean { - heading2 "🐋 Removing $PACKAGE_NAME images" - IMAGES=$(docker images --filter "reference=${PACKAGE_NAME}-asdf*" | tail -n +2) - COUNT_IMAGES=$(echo -n "$IMAGES" | wc -l) - if [[ "$DEBUG" -gt 0 ]]; then - echo "IMAGES=$IMAGES" - echo "COUNT_IMAGES=$COUNT_IMAGES" - fi - - if [[ "$COUNT_IMAGES" -gt 0 ]]; then - docker images | grep "$PACKAGE_NAME" | awk '{OFS=":"} {print $1, $2}' | xargs -t docker rmi - fi -} - - -function docker_clean_unused { - docker images --filter "reference=${PACKAGE_NAME}-*" -a | \ - tail -n +2 | \ - grep -v "$GIT_COMMIT" | \ - awk '{OFS=":"} {print $1, $2}' | \ - xargs -t docker rmi -} - -function docker_autoclean { - if [[ "$CI" = 0 ]]; then - if [[ "$DEBUG" -gt 0 ]]; then - heading2 "🐋 determining if need to clean" - fi - - IMAGES=$( - docker images --filter "reference=${PACKAGE_NAME}-*" -a |\ - tail -n +2 |\ - grep -v "$GIT_COMMIT" ;\ - /bin/true - ) - COUNT_IMAGES=$(echo "$IMAGES" | wc -l) - - if [[ "$DEBUG" -gt 0 ]]; then - echo "IMAGES=${IMAGES}" - echo "COUNT_IMAGES=${COUNT_IMAGES}" - fi - - if [[ $COUNT_IMAGES -gt $AUTOCLEAN_LIMIT ]]; then - heading2 "Removing unused ${PACKAGE_NAME} images 🐋" - docker_clean_unused - fi - fi -} - ## Utility ## ----------------------------------------------------------------------------- function heading { @@ -228,32 +158,6 @@ function check_pyproject_toml { ## Command Functions ## ----------------------------------------------------------------------------- -function command_build { - if [[ -z "$1" || "$1" == "dist" ]]; then - BUILD_DIR="dist" - elif [[ "$1" == "tmp" ]]; then - BUILD_DIR=".tmp/dist" - else - return 1 - fi - - # TODO: unstashed changed guard - - if [[ ! -d "$BUILD_DIR" ]]; then - heading "setup 📜" - mkdir $BUILD_DIR - fi - - echo "BUILD_DIR=${BUILD_DIR}" >> .env - echo "BUILD_DIR=${BUILD_DIR}" >> .tmp/env - - heading "build 🐍" - # Note: we always run compose_build because we copy the package source code to - # the container so we can modify it without affecting local source code. - compose_build python-build - compose_run python-build -} - function display_usage { echo "dev.sh - development utility" @@ -306,70 +210,30 @@ case $1 in echo "ERROR! Do not run format in CI!" exit 250 fi - heading "black 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - - compose_run python-common \ - black --line-length 100 --target-version ${PYTHON_MIN_VERSION} $SOURCE_FILES + heading "tox 🐍 - format" + uvx tox -e format || true ;; "lint") - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - - if [[ "$DEBUG" -gt 0 ]]; then - heading2 "🤔 Debugging" - compose_run python-common ls -lah - compose_run python-common pip list - fi - - heading "validate-pyproject 🐍" - compose_run python-common validate-pyproject pyproject.toml - - heading "black - check only 🐍" - compose_run python-common \ - black --line-length 100 --target-version ${PYTHON_MIN_VERSION} --check --diff $SOURCE_FILES - - heading "pylint 🐍" - compose_run python-common pylint -j 4 --output-format=colorized $SOURCE_FILES - - heading "mypy 🐍" - compose_run python-common mypy $SOURCE_FILES - + heading "tox 🐍 - lint" + uvx tox -e lint || true ;; "test") - command_build tmp - - heading "tox 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-tox - fi - compose_run python-tox tox -e ${PYTHON_MIN_VERSION} || true - - rm -rf .tmp/dist/* + heading "tox 🐍 - single" + uvx tox -e py312 || true ;; "test-full") - command_build tmp - - heading "tox 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-tox - fi - compose_run python-tox tox || true - - rm -rf .tmp/dist/* + heading "tox 🐍 - all" + uvx tox || true ;; "build") - command_build dist + source ./lib/python/build.sh ;; @@ -408,44 +272,20 @@ print('Your package is already imported 🎉\nPress ctrl+d to exit') EOF fi - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - compose_run python-common bpython --config bpython.ini -i .tmp/repl.py - - ;; - - "run") - heading "Running 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - compose_run python-common "${@:2}" + uv run python -i .tmp/repl.py ;; "docs") heading "Preview Docs 🐍" - if [[ "$SKIP_BUILD" = 0 ]]; then - compose_build python-common - fi - compose_run -p 127.0.0.1:${PORT}:8080 python-common mkdocs serve -a 0.0.0.0:8080 -w docs + uv run --extra dev mkdocs serve -w docs ;; "build-docs") heading "Building Docs 🐍" - if [[ -z "$VIRTUAL_ENV" ]]; then - echo "This command should be run in a virtual environment to avoid poluting" - exit 1 - fi - - if [[ -z $(pip3 list | grep mike) ]]; then - pip install -e.[docs] - fi - - mike deploy "$PACKAGE_VERSION" "latest" \ + uv run --extra dev mike deploy "$PACKAGE_VERSION" "latest" \ --update-aliases \ --prop-set-string "git_branch=${GIT_BRANCH}" \ --prop-set-string "git_commit=${GIT_COMMIT}" \ @@ -462,7 +302,6 @@ EOF "clean") heading "Cleaning 📜" - docker_clean echo "🐍 pyclean" if ! command -v pyclean &> /dev/null; then @@ -471,6 +310,9 @@ EOF pyclean src pyclean tests + echo "🐍 clear .tox" + rm -rf .tox + echo "🐍 remove build artifacts" rm -rf build dist "src/${PACKAGE_PYTHON_NAME}.egg-info" @@ -523,5 +365,3 @@ EOF ;; esac - -docker_autoclean diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 86a3c9e..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,42 +0,0 @@ -version: "3.1" -services: - python-common: &pythonBase - image: "${PACKAGE_NAME}-python-general:${GIT_COMMIT}" - build: - context: . - dockerfile: lib/python/common.Dockerfile - args: &pythonBaseBuildArgs - - "SOURCE_UID=${SOURCE_UID}" - - "SOURCE_GID=${SOURCE_GID}" - - "SOURCE_UID_GID=${SOURCE_UID_GID}" - user: devuser - working_dir: /code - env_file: - - .tmp/env - environment: - - "PATH=/home/devuser/.local/bin:/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games" - volumes: - - .:/code - - python-build: - <<: *pythonBase - image: "${PACKAGE_NAME}-python-build:${GIT_COMMIT}" - build: - context: . - dockerfile: lib/python/build.Dockerfile - args: *pythonBaseBuildArgs - command: "/code/lib/python/build.sh" - volumes: - - ./${BUILD_DIR}:/code/dist - - python-tox: - <<: *pythonBase - image: "${PACKAGE_NAME}-python-tox:${GIT_COMMIT}" - build: - context: . - dockerfile: lib/python/tox.Dockerfile - args: *pythonBaseBuildArgs - volumes: - - ./${BUILD_DIR}:/code/dist - - ./tests:/code/tests - - ./tox.ini:/code/tox.ini diff --git a/lib/python/build.Dockerfile b/lib/python/build.Dockerfile deleted file mode 100644 index c25011a..0000000 --- a/lib/python/build.Dockerfile +++ /dev/null @@ -1,24 +0,0 @@ -FROM python:3.7 - -ARG SOURCE_UID -ARG SOURCE_GID -ARG SOURCE_UID_GID - -RUN mkdir -p /code/src \ - && groupadd --gid ${SOURCE_GID} devuser \ - && useradd --uid ${SOURCE_GID} -g devuser --create-home --shel /bin/bash devuser \ - && chown -R ${SOURCE_UID_GID} /code \ - && su - devuser -c "pip install --user --upgrade pip" - -## ^^ copied from common.Dockerfile - try to keep in sync fo caching - -# Base stuff -ADD . /code - -RUN chown -R ${SOURCE_UID_GID} /code # needed twice because added files - -RUN ls -lah /code - -RUN su - devuser -c "cd /code && pip install --user build" - -CMD echo "docker-compose build python-build complete 🎉" diff --git a/lib/python/build.sh b/lib/python/build.sh index c4d7bcc..48c0405 100755 --- a/lib/python/build.sh +++ b/lib/python/build.sh @@ -47,22 +47,7 @@ replace_version_var BUILD_DATETIME "${BUILD_DATETIME}" 0 head -n 22 "src/${PACKAGE_PYTHON_NAME}/_version.py" | tail -n 7 -if [ "$PYTHON_PACKAGE_REPOSITORY" == "testpypi" ]; then - echo "MODIFYING PACKAGE_NAME" - # Replace name suitable for test.pypi.org - # https://packaging.python.org/tutorials/packaging-projects/#creating-setup-py - sed -i "s/^PACKAGE_NAME = .*/PACKAGE_NAME = \"${PACKAGE_NAME}-${TESTPYPI_USERNAME}\"/" setup.py - grep "^PACKAGE_NAME = " setup.py - - mv "src/${PACKAGE_PYTHON_NAME}" "src/${PACKAGE_PYTHON_NAME}_$(echo -n $TESTPYPI_USERNAME | tr '-' '_')" -fi - -if [[ "$GIT_BRANCH" != "master" && "$GIT_BRANCH" != "main" ]]; then - sed -i "s/^PACKAGE_VERSION = .*/PACKAGE_VERSION = \"${BUILD_VERSION}\"/" setup.py - grep "^PACKAGE_VERSION = " setup.py -fi - ## Build ## ----------------------------------------------------------------------------- -#python3 setup.py bdist_wheel -python3 -m build --wheel +uv build +git restore src/${PACKAGE_PYTHON_NAME}/_version.py diff --git a/lib/python/common.Dockerfile b/lib/python/common.Dockerfile deleted file mode 100644 index 668c098..0000000 --- a/lib/python/common.Dockerfile +++ /dev/null @@ -1,26 +0,0 @@ -# syntax = docker/dockerfile:1.2 -FROM python:3.7 - - -ARG SOURCE_UID -ARG SOURCE_GID -ARG SOURCE_UID_GID - -RUN apt update && apt install -y \ - less - -RUN mkdir -p /code/src \ - && groupadd --gid ${SOURCE_GID} devuser \ - && useradd --uid ${SOURCE_GID} -g devuser --create-home --shell /bin/bash devuser \ - && chown -R ${SOURCE_UID_GID} /code \ - && su -l devuser -c "pip install --user --upgrade pip" - -ADD pyproject.toml /code -RUN chown -R ${SOURCE_UID_GID} /code # needed twice because added files - -RUN ls -lah /code /home /home/devuser /home/devuser/.cache /home/devuser/.cache/pip - -RUN --mount=type=cache,target=/home/devuser/.cache,uid=1000,gid=1000 \ - su -l devuser -c "cd /code && pip install --user -e .[dev,docs]" - -CMD echo "docker-compose build python-common complete 🎉" diff --git a/lib/python/install_pypy.sh b/lib/python/install_pypy.sh deleted file mode 100755 index c3547f0..0000000 --- a/lib/python/install_pypy.sh +++ /dev/null @@ -1,39 +0,0 @@ -#!/bin/bash -set -e - -PYPY_VERSION="7.3.9" -PYTHON_VERSIONS="3.7 3.8 3.9" - -# Note: pypy-7.3.9 is last version to support python3.7 - -if [ ! -d /tmp/pypy ]; then - mkdir /tmp/pypy -fi - -cd /tmp/pypy - -for PYTHON_VERSION in $PYTHON_VERSIONS; do - FULLNAME="pypy${PYTHON_VERSION}-v${PYPY_VERSION}-linux64" - FILENAME="${FULLNAME}.tar.bz2" - - if [ ! -f "${FILENAME}" ]; then - # not cached - fetch - echo "Fetching ${FILENAME}" - wget -q "https://downloads.python.org/pypy/${FILENAME}" - fi - - echo "Extracting ${FILENAME} to /opt/${FULLNAME}" - tar xf ${FILENAME} --directory=/opt - - echo "Removing temp file" - rm -f ${FILENAME} - - echo "sanity check" - ls /opt - - echo "Linking ${FULLNAME}/bin/pypy${PYTHON_VERSION} to /usr/bin" - ln -s "/opt/${FULLNAME}/bin/pypy${PYTHON_VERSION}" /usr/bin/ - - echo "" - -done diff --git a/lib/python/tox.Dockerfile b/lib/python/tox.Dockerfile deleted file mode 100644 index e8ad7c9..0000000 --- a/lib/python/tox.Dockerfile +++ /dev/null @@ -1,50 +0,0 @@ -FROM ubuntu:20.04 - -# We use deadsnakes ppa to install -# https://launchpad.net/~deadsnakes/+archive/ubuntu/ppa -# -# As noted in the readme, 22.04 supports only 3.7+, so use 20.04 to support some older versions -# This also means we don't install 3.8 as it is already provided - -# TZ https://serverfault.com/a/1016972 -ARG DEBIAN_FRONTEND=noninteractive -ENV TZ=Etc/UTC - -RUN --mount=target=/var/lib/apt/lists,type=cache,sharing=locked \ - --mount=target=/var/cache/apt,type=cache,sharing=locked \ - rm -f /etc/apt/apt.conf.d/docker-clean \ - && apt update \ - && apt upgrade --yes \ - && apt install --yes software-properties-common wget python3-pip\ - && add-apt-repository ppa:deadsnakes/ppa \ - && apt update --yes - -RUN --mount=target=/var/lib/apt/lists,type=cache,sharing=locked \ - --mount=target=/var/cache/apt,type=cache,sharing=locked \ - apt install --yes \ - python3.6 python3.6-dev python3.6-distutils \ - python3.7 python3.7-dev python3.7-distutils \ - python3.9 python3.9-dev python3.9-distutils \ - python3.10 python3.10-dev python3.10-distutils \ - python3.11 python3.11-dev python3.11-distutils \ - python3.12 python3.12-dev python3.12-distutils - -## pypy -ADD lib/python/install_pypy.sh /tmp -RUN --mount=target=/tmp/pypy,type=cache,sharing=locked \ - /tmp/install_pypy.sh - - -ARG SOURCE_UID -ARG SOURCE_GID -ARG SOURCE_UID_GID - -RUN mkdir -p /code/dist /code/tests \ - && groupadd --gid ${SOURCE_GID} devuser \ - && useradd --uid ${SOURCE_GID} -g devuser --create-home --shell /bin/bash devuser \ - && chown -R ${SOURCE_UID_GID} /code \ - && su - devuser -c "pip install --user --upgrade pip" - -RUN su - devuser -c "pip install --user tox" - -CMD echo "docker-compose build python-tox complete 🎉" diff --git a/pyproject.toml b/pyproject.toml index 9ff5a2d..62c1fdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,16 +4,17 @@ build-backend = "setuptools.build_meta" [project] name = "nserver" -version = "2.0.1.rc1" +version = "3.0.0.dev1" description = "DNS Name Server Framework" authors = [ {name = "Nicholas Hairs", email = "info+nserver@nicholashairs.com"}, ] # Dependency Information -requires-python = ">=3.7" +requires-python = ">=3.8" dependencies = [ "dnslib", + "pillar~=0.3", "tldextract", ] @@ -23,12 +24,12 @@ license = {text = "MIT"} classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "License :: OSI Approved :: MIT License", "Typing :: Typed", "Topic :: Internet", @@ -36,17 +37,13 @@ classifiers = [ ] [project.urls] -homepage = "https://nhairs.github.io/nserver/latest/" -github = "https://github.com/nhairs/nserver" +HomePage = "https://nhairs.github.io/nserver" +GitHub = "https://github.com/nhairs/nserver" [project.optional-dependencies] -build = [ - "setuptools", - "wheel", -] - dev = [ - ### dev.sh dependencies + "tox", + "tox-uv", ## Formatting / Linting "validate-pyproject[all]", "black", @@ -54,11 +51,10 @@ dev = [ "mypy", ## Testing "pytest", - ## REPL - "bpython", -] - -docs = [ + ## Build + "setuptools", + "wheel", + ## Docs "black", "mkdocs", "mkdocs-material>=8.5", @@ -72,3 +68,6 @@ docs = [ [tool.setuptools.package-data] nserver = ["py.typed"] + +[tool.black] +line-length = 100 diff --git a/src/nserver/_version.py b/src/nserver/_version.py index 24869af..722c41c 100644 --- a/src/nserver/_version.py +++ b/src/nserver/_version.py @@ -1,4 +1,5 @@ """Version information for this package.""" + ### IMPORTS ### ============================================================================ ## Standard Library diff --git a/tox.ini b/tox.ini index 5e7f556..21479b3 100644 --- a/tox.ini +++ b/tox.ini @@ -1,10 +1,36 @@ [tox] -envlist = py37,py38,py39,py310,py311,py312,pypy37,pypy38,pypy39 +requires = tox>=3,tox-uv +envlist = pypy{38,39,310}, py{38,39,310,311,312,313} + +[gh-actions] +python = + pypy-3.8: pypy38 + pypy-3.9: pypy39 + pypy-3.10: pypy310 + 3.8: py38 + 3.9: py39 + 3.10: py310 + 3.11: py311 + 3.12: py312 + 3.13: py313 [testenv] -package = external -deps = pytest -commands = {posargs:pytest -ra tests} +description = run unit tests +extras = dev +commands = + pytest tests + +[testenv:format] +description = run formatters +extras = dev +commands = + black src tests -[testenv:.pkg_external] -package_glob = /code/dist/* +[testenv:lint] +description = run linters +extras = dev +commands = + validate-pyproject pyproject.toml + black --check --diff src tests + pylint src + mypy src tests From 240d965ea6d2c5f017937bdf3612c602d805fd76 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 11:50:18 +1100 Subject: [PATCH 03/16] Fix syntax error --- .github/workflows/test-suite.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index c7508e0..a1e050c 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -19,7 +19,7 @@ jobs: - uses: astral-sh/setup-uv@v3 - name: Lint with tox - run: uvx tox --with tox-uv -e lint + run: uvx --with tox-uv tox -e lint test: name: "Python Test ${{matrix.python-version}} ${{ matrix.os }}" From 19b19bb37d9e264179d7f5c3a542ff71d2fd0595 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 11:53:30 +1100 Subject: [PATCH 04/16] Support new pylint checks --- pylintrc | 3 +++ src/nserver/records.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pylintrc b/pylintrc index 776bfe6..76f43a7 100644 --- a/pylintrc +++ b/pylintrc @@ -479,6 +479,9 @@ valid-metaclass-classmethod-first-arg=cls # Maximum number of arguments for function / method. max-args=10 +# Max number of positional arguments for a function / method +max-positional-arguments=8 + # Maximum number of attributes for a class (see R0902). max-attributes=15 diff --git a/src/nserver/records.py b/src/nserver/records.py index 9915f67..71cef29 100644 --- a/src/nserver/records.py +++ b/src/nserver/records.py @@ -222,7 +222,7 @@ class SOA(RecordBase): - https://en.wikipedia.org/wiki/SOA_record """ - def __init__( # pylint: disable=too-many-arguments + def __init__( # pylint: disable=too-many-arguments,too-many-positional-arguments self, zone_name: str, primary_name_server: str, From a75b3f015b0fead8754016ba5e660ffaf2cf98a3 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 11:56:36 +1100 Subject: [PATCH 05/16] Fix GHA error --- .github/workflows/test-suite.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index a1e050c..b48c437 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -51,4 +51,4 @@ jobs: run: uv python install ${{ matrix.python-version }} - name: Test with tox - run: uvx --with tox-uv tox-gh-actions tox + run: uvx --with tox-uv,tox-gh-actions tox From 3f5893923e59dab1f029716ec66ab8c2d332b5f5 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 12:06:17 +1100 Subject: [PATCH 06/16] fix gha --- .github/workflows/test-suite.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index b48c437..e4f64d8 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -29,9 +29,9 @@ jobs: fail-fast: false # allow tests to run on all platforms matrix: python-version: - - "pypy-3.8" - - "pypy-3.9" - - "pypy-3.10" + - "pypy@3.8" + - "pypy@3.9" + - "pypy@3.10" - "3.8" - "3.9" - "3.10" From 9e72a27eaa92479a68aff35035be8a89ed8000bc Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 12:16:34 +1100 Subject: [PATCH 07/16] ignore pypy in test matrix --- .github/workflows/test-suite.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index e4f64d8..9f84b16 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -29,9 +29,10 @@ jobs: fail-fast: false # allow tests to run on all platforms matrix: python-version: - - "pypy@3.8" - - "pypy@3.9" - - "pypy@3.10" + # Tox is not picking up pypi - ignore for now + #- "pypy@3.8" + #- "pypy@3.9" + #- "pypy@3.10" - "3.8" - "3.9" - "3.10" From 12a0659eb0adb345da68a93256fe9a3c11cb64e2 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 12:21:05 +1100 Subject: [PATCH 08/16] Don't use tox-gh-actions --- .github/workflows/test-suite.yml | 16 +--------------- tox.ini | 12 ------------ 2 files changed, 1 insertion(+), 27 deletions(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 9f84b16..b5192d3 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -28,17 +28,6 @@ jobs: strategy: fail-fast: false # allow tests to run on all platforms matrix: - python-version: - # Tox is not picking up pypi - ignore for now - #- "pypy@3.8" - #- "pypy@3.9" - #- "pypy@3.10" - - "3.8" - - "3.9" - - "3.10" - - "3.11" - - "3.12" - - "3.13" os: - ubuntu-latest - windows-latest @@ -48,8 +37,5 @@ jobs: - uses: actions/checkout@v4 - uses: astral-sh/setup-uv@v3 - - name: Set up Python ${{ matrix.python-version }} - run: uv python install ${{ matrix.python-version }} - - name: Test with tox - run: uvx --with tox-uv,tox-gh-actions tox + run: uvx --with tox-uv tox diff --git a/tox.ini b/tox.ini index 21479b3..0aaa4fb 100644 --- a/tox.ini +++ b/tox.ini @@ -2,18 +2,6 @@ requires = tox>=3,tox-uv envlist = pypy{38,39,310}, py{38,39,310,311,312,313} -[gh-actions] -python = - pypy-3.8: pypy38 - pypy-3.9: pypy39 - pypy-3.10: pypy310 - 3.8: py38 - 3.9: py39 - 3.10: py310 - 3.11: py311 - 3.12: py312 - 3.13: py313 - [testenv] description = run unit tests extras = dev From d16e29141348a1bb86536501768f07d83101338b Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 12:21:44 +1100 Subject: [PATCH 09/16] Fix GHA error --- .github/workflows/test-suite.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index b5192d3..64af757 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -22,7 +22,7 @@ jobs: run: uvx --with tox-uv tox -e lint test: - name: "Python Test ${{matrix.python-version}} ${{ matrix.os }}" + name: "Python Test ${{ matrix.os }}" needs: [lint] runs-on: "${{ matrix.os }}" strategy: From c304634bad11a2be66508951fd8b7c1024e2326b Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sat, 2 Nov 2024 12:51:16 +1100 Subject: [PATCH 10/16] Ignore uv.lock --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 0f424ef..a8247d8 100644 --- a/.gitignore +++ b/.gitignore @@ -194,3 +194,4 @@ dmypy.json ### PROJECT ### ============================================================================ # Project specific stuff goes here +uv.lock From b196aa6b235d9213846cefd67eb1db688e9ba477 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 10 Nov 2024 13:10:09 +1100 Subject: [PATCH 11/16] Refactor middleware and servers --- pylintrc | 2 +- src/nserver/__init__.py | 3 +- src/nserver/application.py | 109 +++++++ src/nserver/cli.py | 132 ++++++++ src/nserver/exceptions.py | 8 +- src/nserver/middleware.py | 428 ++++++++++--------------- src/nserver/models.py | 11 +- src/nserver/records.py | 11 +- src/nserver/rules.py | 53 +++- src/nserver/server.py | 620 ++++++++++++++----------------------- src/nserver/settings.py | 35 --- src/nserver/transport.py | 57 ++-- src/nserver/util.py | 3 + tests/test_blueprint.py | 9 +- tests/test_server.py | 22 +- tests/test_subserver.py | 48 ++- 16 files changed, 768 insertions(+), 783 deletions(-) create mode 100644 src/nserver/application.py create mode 100644 src/nserver/cli.py delete mode 100644 src/nserver/settings.py diff --git a/pylintrc b/pylintrc index 76f43a7..2dd7180 100644 --- a/pylintrc +++ b/pylintrc @@ -456,7 +456,7 @@ preferred-modules= # List of method names used to declare (i.e. assign) instance attributes. defining-attr-methods=__init__, __new__, - setUp, + setup, __post_init__ # List of member names, which should be excluded from the protected access diff --git a/src/nserver/__init__.py b/src/nserver/__init__.py index d4fef35..0758315 100644 --- a/src/nserver/__init__.py +++ b/src/nserver/__init__.py @@ -1,5 +1,4 @@ from .models import Query, Response from .rules import ALL_QTYPES, StaticRule, ZoneRule, RegexRule, WildcardStringRule from .records import A, AAAA, NS, CNAME, PTR, SOA, MX, TXT, CAA -from .server import NameServer, SubServer, Blueprint -from .settings import Settings +from .server import NameServer, RawNameServer, Blueprint diff --git a/src/nserver/application.py b/src/nserver/application.py new file mode 100644 index 0000000..195bc32 --- /dev/null +++ b/src/nserver/application.py @@ -0,0 +1,109 @@ +### IMPORTS +### ============================================================================ +## Future +from __future__ import annotations + +## Standard Library + +## Installed +from pillar.logging import LoggingMixin + +## Application +from .exceptions import InvalidMessageError +from .server import NameServer, RawNameServer +from .transport import TransportBase + + +### CLASSES +### ============================================================================ +class BaseApplication(LoggingMixin): + """Base class for all application classes. + + New in `3.0`. + """ + + def __init__(self, server: NameServer | RawNameServer) -> None: + if isinstance(server, NameServer): + server = RawNameServer(server) + self.server: RawNameServer = server + self.logger = self.get_logger() + return + + def run(self) -> int | None: + """Run this application. + + Child classes must override this method. + + Returns: + Integer status code to be returned. `None` will be treated as `0`. + """ + raise NotImplementedError() + + +class DirectApplication(BaseApplication): + """Application that directly runs the server. + + New in `3.0`. + """ + + MAX_ERRORS: int = 10 + + exit_code: int + + def __init__(self, server: NameServer | RawNameServer, transport: TransportBase) -> None: + super().__init__(server) + self.transport = transport + self.exit_code = 0 + self.shutdown_server = False + return + + def run(self) -> int: + """Start running the server + + Returns: + `exit_code`, `0` if exited normally + """ + # Start Server + # TODO: Do we want to recreate the transport instance or do we assume that + # transport.shutdown_server puts it back into a ready state? + # We could make this configurable? :thonking: + + self.info(f"Starting {self.transport}") + try: + self.transport.start_server() + except Exception as e: # pylint: disable=broad-except + self.critical(f"Failed to start server. {e}", exc_info=e) + self.exit_code = 1 + return self.exit_code + + # Process Requests + error_count = 0 + while True: + if self.shutdown_server: + break + + try: + message = self.transport.receive_message() + message.response = self.server.process_request(message.message) + self.transport.send_message_response(message) + + except InvalidMessageError as e: + self.warning(f"{e}") + + except Exception as e: # pylint: disable=broad-except + self.error(f"Uncaught error occured. {e}", exc_info=e) + error_count += 1 + if self.MAX_ERRORS and error_count >= self.MAX_ERRORS: + self.critical(f"Max errors hit ({error_count})") + self.shutdown_server = True + self.exit_code = 1 + + except KeyboardInterrupt: + self.info("KeyboardInterrupt received.") + self.shutdown_server = True + + # Stop Server + self.info("Shutting down server") + self.transport.stop_server() + + return self.exit_code diff --git a/src/nserver/cli.py b/src/nserver/cli.py new file mode 100644 index 0000000..96be94b --- /dev/null +++ b/src/nserver/cli.py @@ -0,0 +1,132 @@ +### IMPORTS +### ============================================================================ +## Future +from __future__ import annotations + +## Standard Library +import argparse +import importlib + +## Installed +import pillar.application + +## Application +from . import transport +from . import _version + +from .application import BaseApplication, DirectApplication +from .server import NameServer, RawNameServer + + +### CLASSES +### ============================================================================ +class CliApplication(pillar.application.Application): + """NServer CLI tool for running servers""" + + application_name = "nserver" + name = "nserver" + version = _version.VERSION_INFO_FULL + epilog = "For full information including licence see https://github.com/nhairs/nserver" + + config_args_enabled = False + + def get_argument_parser(self) -> argparse.ArgumentParser: + parser = super().get_argument_parser() + + ## Server + ## --------------------------------------------------------------------- + parser.add_argument( + "--server", + action="store", + help=( + "Import path of server / factory to run in the form of " + "package.module.path:attribute" + ), + ) + + ## Transport + ## --------------------------------------------------------------------- + parser.add_argument( + "--host", + action="store", + default="localhost", + help="Host (IP) to bind to. Defaults to localhost.", + ) + + parser.add_argument( + "--port", + action="store", + default=5300, + type=int, + help="Port to bind to. Defaults to 5300.", + ) + + transport_group = parser.add_mutually_exclusive_group() + transport_group.add_argument( + "--udp", + action="store_const", + const=transport.UDPv4Transport, + dest="transport", + help="Use UDPv4 socket for transport. (default)", + ) + transport_group.add_argument( + "--udp6", + action="store_const", + const=transport.UDPv6Transport, + dest="transport", + help="Use UDPv6 socket for transport.", + ) + transport_group.add_argument( + "--tcp", + action="store_const", + const=transport.TCPv4Transport, + dest="transport", + help="Use TCPv4 socket for transport.", + ) + + parser.set_defaults(transport=transport.UDPv4Transport) + return parser + + def setup(self, *args, **kwargs) -> None: + super().setup(*args, **kwargs) + + self.server = self.get_server() + self.application = self.get_application() + return + + def main(self) -> int | None: + return self.application.run() + + def get_server(self) -> NameServer | RawNameServer: + """Factory for getting the server to run based on current settings""" + module_path, attribute_path = self.args.server.split(":") + obj: object = importlib.import_module(module_path) + + for attribute_name in attribute_path.split("."): + obj = getattr(obj, attribute_name) + + if isinstance(obj, (NameServer, RawNameServer)): + return obj + + # Assume callable (will throw error if not) + server = obj() # type: ignore[operator] + + if isinstance(server, (NameServer, RawNameServer)): + return server + + raise TypeError(f"Imported factory ({obj}) did not return a server ({server})") + + def get_application(self) -> BaseApplication: + """Factory for getting the application based on current settings""" + application = DirectApplication( + self.server, + self.args.transport(self.args.host, self.args.port), + ) + return application + + +### MAIN +### ============================================================================ +if __name__ == "__main__": + app = CliApplication() + app.run() diff --git a/src/nserver/exceptions.py b/src/nserver/exceptions.py index c72946d..89cc2ca 100644 --- a/src/nserver/exceptions.py +++ b/src/nserver/exceptions.py @@ -1,11 +1,11 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library import base64 -# Note: Union can only be replaced with `X | Y` in 3.10+ -from typing import Tuple, Union - ## Installed ## Application @@ -17,7 +17,7 @@ class InvalidMessageError(ValueError): """An invalid DNS message""" def __init__( - self, error: Exception, raw_data: bytes, remote_address: Union[str, Tuple[str, int]] + self, error: Exception, raw_data: bytes, remote_address: str | tuple[str, int] ) -> None: """ Args: diff --git a/src/nserver/middleware.py b/src/nserver/middleware.py index 2662bd7..66d7cab 100644 --- a/src/nserver/middleware.py +++ b/src/nserver/middleware.py @@ -1,163 +1,212 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library import inspect import threading -from typing import Callable, Dict, List, Type, Optional +from typing import TYPE_CHECKING, Callable, Generic, TypeVar, TypeAlias ## Installed import dnslib +from pillar.logging import LoggingMixin ## Application from .models import Query, Response -from .records import RecordBase -from .rules import RuleBase, RuleResult - +from .rules import coerce_to_response, RuleResult ### CONSTANTS ### ============================================================================ +# pylint: disable=invalid-name +T_request = TypeVar("T_request") +T_response = TypeVar("T_response") +# pylint: enable=invalid-name + ## Query Middleware -QueryMiddlewareCallable = Callable[[Query], Response] +## ----------------------------------------------------------------------------- +QueryCallable: TypeAlias = Callable[[Query], Response] """Type alias for functions that can be used with `QueryMiddleware.next_function`""" -ExceptionHandler = Callable[[Query, Exception], Response] +QueryExceptionHandler: TypeAlias = Callable[[Query, Exception], Response] """Type alias for `ExceptionHandlerMiddleware` exception handler functions""" # Hooks -BeforeFirstQueryHook = Callable[[], None] +BeforeFirstQueryHook: TypeAlias = Callable[[], None] """Type alias for `HookMiddleware.before_first_query` functions.""" -BeforeQueryHook = Callable[[Query], RuleResult] +BeforeQueryHook: TypeAlias = Callable[[Query], RuleResult] """Type alias for `HookMiddleware.before_query` functions.""" -AfterQueryHook = Callable[[Response], Response] +AfterQueryHook: TypeAlias = Callable[[Response], Response] """Type alias for `HookMiddleware.after_query` functions.""" ## RawRecordMiddleware -RawRecordMiddlewareCallable = Callable[[dnslib.DNSRecord], dnslib.DNSRecord] -"""Type alias for functions that can be used with `RawRecordMiddleware.next_function`""" - -RawRecordExceptionHandler = Callable[[dnslib.DNSRecord, Exception], dnslib.DNSRecord] -"""Type alias for `RawRecordExceptionHandlerMiddleware` exception handler functions""" - - -### FUNCTIONS -### ============================================================================ -def coerce_to_response(result: RuleResult) -> Response: - """Convert some `RuleResult` to a `Response` - - New in `2.0`. - - Args: - result: the results to convert - - Raises: - TypeError: unsupported result type - """ - if isinstance(result, Response): - return result +## ----------------------------------------------------------------------------- +if TYPE_CHECKING: - if result is None: - return Response() + class RawRecord(dnslib.DNSRecord): + "Dummy class for type checking as dnslib is not typed" - if isinstance(result, RecordBase) and result.__class__ is not RecordBase: - return Response(answers=result) +else: + RawRecord: TypeAlias = dnslib.DNSRecord + """Type alias for raw records to allow easy changing of implementation details""" - if isinstance(result, list) and all(isinstance(item, RecordBase) for item in result): - return Response(answers=result) +RawMiddlewareCallable: TypeAlias = Callable[[RawRecord], RawRecord] +"""Type alias for functions that can be used with `RawRecordMiddleware.next_function`""" - raise TypeError(f"Cannot process result: {result!r}") +RawExceptionHandler: TypeAlias = Callable[[RawRecord, Exception], RawRecord] +"""Type alias for `RawRecordExceptionHandlerMiddleware` exception handler functions""" ### CLASSES ### ============================================================================ -## Request Middleware +## Generic Base Classes ## ----------------------------------------------------------------------------- -class QueryMiddleware: - """Middleware for interacting with `Query` objects +class MiddlewareBase(Generic[T_request, T_response], LoggingMixin): + """Generic base class for middleware classes. - New in `2.0`. + New in `3.0`. """ def __init__(self) -> None: - self.next_function: Optional[QueryMiddlewareCallable] = None + self.next_function: Callable[[T_request], T_response] | None = None + self.logger = self.get_logger() return - def __call__(self, query: Query) -> Response: + def __call__(self, request: T_request) -> T_response: + """Call this middleware + + Args: + request: request to process + + Raises: + RuntimeError: If `next_function` is not set. + """ + if self.next_function is None: - raise RuntimeError("next_function is not set") - return self.process_query(query, self.next_function) + raise RuntimeError("next_function is not set. Need to call register_next_function.") + return self.process_request(request, self.next_function) + + def set_next_function(self, next_function: Callable[[T_request], T_response]) -> None: + """Set the `next_function` of this middleware - def register_next_function(self, next_function: QueryMiddlewareCallable) -> None: - """Set the `next_function` of this middleware""" + Args: + next_function: Callable that this middleware should call next. + """ if self.next_function is not None: - raise RuntimeError("next_function is already set") + raise RuntimeError(f"next_function is already set to {self.next_function}") self.next_function = next_function return - def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Response: - """Handle an incoming query. + def process_request( + self, request: T_request, call_next: Callable[[T_request], T_response] + ) -> T_response: + """Process a given request - Child classes should override this function (if they do not this middleware will - simply pass the query onto the next function). - - Args: - query: the incoming query - call_next: the next function in the chain + Child classes should override this method with their own logic. """ - return call_next(query) + return call_next(request) -class ExceptionHandlerMiddleware(QueryMiddleware): - """Middleware for handling exceptions originating from a `QueryMiddleware` stack. - - Allows registering handlers for individual `Exception` types. Only one handler can - exist for a given `Exception` type. - - When an exception is encountered, the middleware will search for the first handler that - matches the class or parent class of the exception in method resolution order. If no handler - is registered will use this classes `self.default_exception_handler`. - - New in `2.0`. +class ExceptionHandlerBase(MiddlewareBase[T_request, T_response]): + """Generic base class for middleware exception handlers Attributes: - exception_handlers: registered exception handlers + handlers: registered exception handlers + + New in `3.0`. """ def __init__( - self, exception_handlers: Optional[Dict[Type[Exception], ExceptionHandler]] = None + self, + handlers: dict[type[Exception], Callable[[T_request, Exception], T_response]] | None = None, ) -> None: - """ - Args: - exception_handlers: exception handlers to assign - """ super().__init__() - self.exception_handlers = exception_handlers if exception_handlers is not None else {} + self.handlers: dict[type[Exception], Callable[[T_request, Exception], T_response]] = ( + handlers if handlers is not None else {} + ) return - def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Response: - """Call the next function catching any handling any errors""" + def process_request(self, request, call_next): + """Call the next function handling any exceptions that arise""" try: - response = call_next(query) + response = call_next(request) except Exception as e: # pylint: disable=broad-except - handler = self.get_exception_handler(e) - response = handler(query, e) + handler = self.get_handler(e) + response = handler(request, e) return response - def get_exception_handler(self, exception: Exception) -> ExceptionHandler: - """Get the exception handler for an `Exception`. + def set_handler( + self, + exception_class: type[Exception], + handler: Callable[[T_request, Exception], T_response], + *, + allow_overwrite: bool = False, + ) -> None: + """Add an exception handler for the given exception class + + Args: + exception_class: Exceptions to associate with this handler. + handler: The handler to add. + allow_overwrite: Allow overwriting existing handlers. + + Raises: + ValueError: If a handler already exists for the given exception and + `allow_overwrite` is `False`. + """ + if exception_class in self.handlers and not allow_overwrite: + raise ValueError( + f"Exception handler already exists for {exception_class} and allow_overwrite is False" + ) + self.handlers[exception_class] = handler + return + + def get_handler(self, exception: Exception) -> Callable[[T_request, Exception], T_response]: + """Get the exception handler for the given exception Args: exception: the exception we wish to handle """ for class_ in inspect.getmro(exception.__class__): - if class_ in self.exception_handlers: - return self.exception_handlers[class_] + if class_ in self.handlers: + return self.handlers[class_] # No exception handler found - use default handler - return self.default_exception_handler + return self.default_handler + + @staticmethod + def default_handler(request: T_request, exception: Exception) -> T_response: + """Default exception handler + + Child classes MUST override this method. + """ + raise NotImplementedError("Must overide this method") + + +## Request Middleware +## ----------------------------------------------------------------------------- +class QueryMiddleware(MiddlewareBase[Query, Response]): + """Middleware for interacting with `Query` objects + + New in `3.0`. + """ + + +class QueryExceptionHandlerMiddleware(ExceptionHandlerBase[Query, Response], QueryMiddleware): + """Middleware for handling exceptions originating from a `QueryMiddleware` stack. + + Allows registering handlers for individual `Exception` types. Only one handler can + exist for a given `Exception` type. + + When an exception is encountered, the middleware will search for the first handler that + matches the class or parent class of the exception in method resolution order. If no handler + is registered will use this classes `self.default_exception_handler`. + + New in `3.0`. + """ @staticmethod - def default_exception_handler(query: Query, exception: Exception) -> Response: + def default_handler(request: Query, exception: Exception) -> Response: """The default exception handler""" # pylint: disable=unused-argument return Response(error_code=dnslib.RCODE.SERVFAIL) @@ -182,21 +231,21 @@ class HookMiddleware(QueryMiddleware): hook or from the next function in the middleware chain. They take a `Response` input and must return a `Response`. - New in `2.0`. - Attributes: before_first_query: `before_first_query` hooks before_query: `before_query` hooks after_query: `after_query` hooks before_first_query_run: have we run the `before_first_query` hooks before_first_query_failed: did any `before_first_query` hooks fail + + New in `3.0`. """ def __init__( self, - before_first_query: Optional[List[BeforeFirstQueryHook]] = None, - before_query: Optional[List[BeforeQueryHook]] = None, - after_query: Optional[List[AfterQueryHook]] = None, + before_first_query: list[BeforeFirstQueryHook] | None = None, + before_query: list[BeforeQueryHook] | None = None, + after_query: list[AfterQueryHook] | None = None, ) -> None: """ Args: @@ -205,26 +254,24 @@ def __init__( after_query: initial `after_query` hooks to register """ super().__init__() - self.before_first_query: List[BeforeFirstQueryHook] = ( + self.before_first_query: list[BeforeFirstQueryHook] = ( before_first_query if before_first_query is not None else [] ) - self.before_query: List[BeforeQueryHook] = before_query if before_query is not None else [] - self.after_query: List[AfterQueryHook] = after_query if after_query is not None else [] + self.before_query: list[BeforeQueryHook] = before_query if before_query is not None else [] + self.after_query: list[AfterQueryHook] = after_query if after_query is not None else [] self.before_first_query_run: bool = False self.before_first_query_failed: bool = False self._before_first_query_lock = threading.Lock() return - def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Response: - """Process a query running relevant hooks.""" + def process_request(self, request: Query, call_next: QueryCallable) -> Response: with self._before_first_query_lock: if not self.before_first_query_run: - # self._debug("Running before_first_query") self.before_first_query_run = True try: for before_first_query_hook in self.before_first_query: - # self._vdebug(f"Running before_first_query func: {hook}") + self.vdebug(f"Running before_first_query_hook: {before_first_query_hook}") before_first_query_hook() except Exception: self.before_first_query_failed = True @@ -233,92 +280,34 @@ def process_query(self, query: Query, call_next: QueryMiddlewareCallable) -> Res result: RuleResult for before_query_hook in self.before_query: - result = before_query_hook(query) + self.vdebug(f"Running before_query_hook: {before_query_hook}") + result = before_query_hook(request) if result is not None: - # self._debug(f"Got result from before_hook: {hook}") + self.debug(f"Got result from before_query_hook: {before_query_hook}") break else: # No before query hooks returned a response - keep going - result = call_next(query) + result = call_next(request) response = coerce_to_response(result) for after_query_hook in self.after_query: + self.vdebug(f"Running after_query_hook: {after_query_hook}") response = after_query_hook(response) return response -# Final callable -# .............................................................................. -# This is not a QueryMiddleware - it is however the end of the line for all QueryMiddleware -class RuleProcessor: - """Find and run a matching rule function. - - This class serves as the bottom of the `QueryMiddleware` stack. - - New in `2.0`. - """ - - def __init__(self, rules: List[RuleBase]) -> None: - """ - Args: - rules: rules to run against - """ - self.rules = rules - return - - def __call__(self, query: Query) -> Response: - for rule in self.rules: - rule_func = rule.get_func(query) - if rule_func is not None: - # self._info(f"Matched Rule: {rule}") - return coerce_to_response(rule_func(query)) - - # self._info("Did not match any rule") - return Response(error_code=dnslib.RCODE.NXDOMAIN) - - ## Raw Middleware ## ----------------------------------------------------------------------------- -class RawRecordMiddleware: +class RawMiddleware(MiddlewareBase[RawRecord, RawRecord]): """Middleware to be run against raw `dnslib.DNSRecord`s. - New in `2.0`. + New in `3.0`. """ - def __init__(self) -> None: - self.next_function: Optional[RawRecordMiddlewareCallable] = None - return - - def __call__(self, record: dnslib.DNSRecord) -> None: - if self.next_function is None: - raise RuntimeError("next_function is not set") - return self.process_record(record, self.next_function) - - def register_next_function(self, next_function: RawRecordMiddlewareCallable) -> None: - """Set the `next_function` of this middleware""" - if self.next_function is not None: - raise RuntimeError("next_function is already set") - self.next_function = next_function - return - - def process_record( - self, record: dnslib.DNSRecord, call_next: RawRecordMiddlewareCallable - ) -> dnslib.DNSRecord: - """Handle an incoming record. - - Child classes should override this function (if they do not this middleware will - simply pass the record onto the next function). - Args: - record: the incoming record - call_next: the next function in the chain - """ - return call_next(record) - - -class RawRecordExceptionHandlerMiddleware(RawRecordMiddleware): +class RawExceptionHandlerMiddleware(ExceptionHandlerBase[RawRecord, RawRecord]): """Middleware for handling exceptions originating from a `RawRecordMiddleware` stack. Allows registering handlers for individual `Exception` types. Only one handler can @@ -326,109 +315,36 @@ class RawRecordExceptionHandlerMiddleware(RawRecordMiddleware): When an exception is encountered, the middleware will search for the first handler that matches the class or parent class of the exception in method resolution order. If no handler - is registered will use this classes `self.default_exception_handler`. + is registered will use this classes `self.default_handler`. Danger: Important Exception handlers are expected to be robust - that is, they must always return correctly even if they internally encounter an `Exception`. - New in `2.0`. - Attributes: - exception_handlers: registered exception handlers - """ - - def __init__( - self, exception_handlers: Optional[Dict[Type[Exception], RawRecordExceptionHandler]] = None - ) -> None: - super().__init__() - self.exception_handlers: Dict[Type[Exception], RawRecordExceptionHandler] = ( - exception_handlers if exception_handlers is not None else {} - ) - return - - def process_record( - self, record: dnslib.DNSRecord, call_next: RawRecordMiddlewareCallable - ) -> dnslib.DNSRecord: - """Call the next function handling any exceptions that arise""" - try: - response = call_next(record) - except Exception as e: # pylint: disable=broad-except - handler = self.get_exception_handler(e) - response = handler(record, e) - return response - - def get_exception_handler(self, exception: Exception) -> RawRecordExceptionHandler: - """Get the exception handler for the given exception + handlers: registered exception handlers - Args: - exception: the exception we wish to handle - """ - for class_ in inspect.getmro(exception.__class__): - if class_ in self.exception_handlers: - return self.exception_handlers[class_] - # No exception handler found - use default handler - return self.default_exception_handler + New in `3.0`. + """ @staticmethod - def default_exception_handler( - record: dnslib.DNSRecord, exception: Exception - ) -> dnslib.DNSRecord: + def default_handler(request: RawRecord, exception: Exception) -> RawRecord: """Default exception handler""" # pylint: disable=unused-argument - response = record.reply() + response = request.reply() response.header.rcode = dnslib.RCODE.SERVFAIL return response -# Final Callable -# .............................................................................. -# This is not a RawRcordMiddleware - it is however the end of the line for all RawRecordMiddleware -class QueryMiddlewareProcessor: - """Convert an incoming DNS record and pass it to a `QueryMiddleware` stack. - - This class serves as the bottom of the `RawRcordMiddleware` stack. - - New in `2.0`. - """ - - def __init__(self, query_middleware: QueryMiddlewareCallable) -> None: - """ - Args: - query_middleware: the top of the middleware stack - """ - self.query_middleware = query_middleware - return - - def __call__(self, record: dnslib.DNSRecord) -> dnslib.DNSRecord: - response = record.reply() - - if record.header.opcode != dnslib.OPCODE.QUERY: - # self._info(f"Received non-query opcode: {record.header.opcode}") - # This server only response to DNS queries - response.header.rcode = dnslib.RCODE.NOTIMP - return response - - if len(record.questions) != 1: - # self._info(f"Received len(questions_ != 1 ({record.questions})") - # To simplify things we only respond if there is 1 question. - # This is apparently common amongst DNS server implementations. - # For more information see the responses to this SO question: - # https://stackoverflow.com/q/4082081 - response.header.rcode = dnslib.RCODE.REFUSED - return response - - try: - query = Query.from_dns_question(record.questions[0]) - except ValueError: - # self._warning(e) - response.header.rcode = dnslib.RCODE.FORMERR - return response - - result = self.query_middleware(query) - - response.add_answer(*result.get_answer_records()) - response.add_ar(*result.get_additional_records()) - response.add_auth(*result.get_authority_records()) - response.header.rcode = result.error_code - return response +### TYPE_CHECKING +### ============================================================================ +if TYPE_CHECKING and False: # pylint: disable=condition-evals-to-constant + # pylint: disable=undefined-variable + q1 = QueryExceptionHandlerMiddleware() + reveal_type(q1) + reveal_type(q1.handlers) + reveal_type(q1.default_handler) + r1 = RawExceptionHandlerMiddleware() + reveal_type(r1) + reveal_type(r1.handlers) + reveal_type(r1.default_handler) diff --git a/src/nserver/models.py b/src/nserver/models.py index a4626a3..dee7117 100644 --- a/src/nserver/models.py +++ b/src/nserver/models.py @@ -1,5 +1,8 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library from typing import Optional, Union, List @@ -37,7 +40,7 @@ def __init__(self, qtype: str, name: str) -> None: return @classmethod - def from_dns_question(cls, question: dnslib.DNSQuestion) -> "Query": + def from_dns_question(cls, question: dnslib.DNSQuestion) -> Query: """Create a new query from a `dnslib.DNSQuestion`""" if question.qtype not in dnslib.QTYPE.forward: raise ValueError(f"Invalid QTYPE: {question.qtype}") @@ -106,14 +109,14 @@ def __repr__(self) -> str: def __str__(self) -> str: return self.__repr__() - def get_answer_records(self) -> List[dnslib.RD]: + def get_answer_records(self) -> list[dnslib.RD]: """Prepare resource records for answer section""" return [record.to_resource_record() for record in self.answers] - def get_additional_records(self) -> List[dnslib.RD]: + def get_additional_records(self) -> list[dnslib.RD]: """Prepare resource records for additional section""" return [record.to_resource_record() for record in self.additional] - def get_authority_records(self) -> List[dnslib.RD]: + def get_authority_records(self) -> list[dnslib.RD]: """Prepare resource records for authority section""" return [record.to_resource_record() for record in self.authority] diff --git a/src/nserver/records.py b/src/nserver/records.py index 71cef29..3510409 100644 --- a/src/nserver/records.py +++ b/src/nserver/records.py @@ -2,10 +2,13 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library from ipaddress import IPv4Address, IPv6Address import re -from typing import Any, Union, Dict +from typing import Any ## Installed import dnslib @@ -36,7 +39,7 @@ def __init__(self, resource_name: str, ttl: int) -> None: type_name = self.__class__.__name__ self._qtype = getattr(dnslib.QTYPE, type_name) self._class = getattr(dnslib, type_name) # class means python class not RR CLASS - self._record_kwargs: Dict[str, Any] + self._record_kwargs: dict[str, Any] is_unsigned_int_size(ttl, 32, throw_error=True, value_name="ttl") self.ttl = ttl self.resource_name = resource_name @@ -56,7 +59,7 @@ def to_resource_record(self) -> dnslib.RR: class A(RecordBase): # pylint: disable=invalid-name """Ipv4 Address (`A`) Record.""" - def __init__(self, resource_name: str, ip: Union[str, IPv4Address], ttl: int = 300) -> None: + def __init__(self, resource_name: str, ip: str | IPv4Address, ttl: int = 300) -> None: """ Args: resource_name: DNS resource name @@ -77,7 +80,7 @@ def __init__(self, resource_name: str, ip: Union[str, IPv4Address], ttl: int = 3 class AAAA(RecordBase): """Ipv6 Address (`AAAA`) Record.""" - def __init__(self, resource_name: str, ip: Union[str, IPv6Address], ttl: int = 300) -> None: + def __init__(self, resource_name: str, ip: str | IPv6Address, ttl: int = 300) -> None: """ Args: resource_name: DNS resource name diff --git a/src/nserver/rules.py b/src/nserver/rules.py index 3abf5bc..a3b111d 100644 --- a/src/nserver/rules.py +++ b/src/nserver/rules.py @@ -2,9 +2,12 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library import re -from typing import Callable, List, Optional, Pattern, Union, Type +from typing import Callable, Pattern, Union, Type, List ## Installed import dnslib @@ -16,7 +19,7 @@ ### CONSTANTS ### ============================================================================ -ALL_QTYPES: List[str] = list(dnslib.QTYPE.reverse.keys()) +ALL_QTYPES: list[str] = list(dnslib.QTYPE.reverse.keys()) """All supported Query Types New in `2.0`. @@ -27,7 +30,33 @@ ### FUNCTIONS ### ============================================================================ -def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs) -> "RuleBase": +def coerce_to_response(result: RuleResult) -> Response: + """Convert some `RuleResult` to a `Response` + + Args: + result: the results to convert + + Raises: + TypeError: unsupported result type + + New in `3.0`. + """ + if isinstance(result, Response): + return result + + if result is None: + return Response() + + if isinstance(result, RecordBase) and result.__class__ is not RecordBase: + return Response(answers=result) + + if isinstance(result, list) and all(isinstance(item, RecordBase) for item in result): + return Response(answers=result) + + raise TypeError(f"Cannot process result: {result!r}") + + +def smart_make_rule(rule: Union[Type[RuleBase], str, Pattern], *args, **kwargs) -> RuleBase: """Create a rule using shorthand notation. The exact type of rule returned depends on what is povided by `rule`. @@ -76,7 +105,7 @@ def smart_make_rule(rule: "Union[Type[RuleBase], str, Pattern]", *args, **kwargs class RuleBase: """Base class for all Rules to inherit from.""" - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """From the given query return the function to run, if any. If no function should be run (i.e. because it does not match the rule), @@ -99,7 +128,7 @@ class StaticRule(RuleBase): def __init__( self, match_string: str, - allowed_qtypes: List[str], + allowed_qtypes: list[str], func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -116,7 +145,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if query.type not in self.allowed_qtypes: return None @@ -147,7 +176,7 @@ class ZoneRule(RuleBase): def __init__( self, zone: str, - allowed_qtypes: List[str], + allowed_qtypes: list[str], func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -165,7 +194,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if self.allowed_qtypes is not None and query.type not in self.allowed_qtypes: return None @@ -194,7 +223,7 @@ class RegexRule(RuleBase): def __init__( self, regex: Pattern, - allowed_qtypes: List[str], + allowed_qtypes: list[str], func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -219,7 +248,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if query.type not in self.allowed_qtypes: return None @@ -257,7 +286,7 @@ class WildcardStringRule(RuleBase): def __init__( self, wildcard_string: str, - allowed_qtypes: List, + allowed_qtypes: list, func: ResponseFunction, case_sensitive: bool = False, ) -> None: @@ -274,7 +303,7 @@ def __init__( self.case_sensitive = case_sensitive return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: """Same as parent class""" if query.type not in self.allowed_qtypes: return None diff --git a/src/nserver/server.py b/src/nserver/server.py index 0482914..6f88616 100644 --- a/src/nserver/server.py +++ b/src/nserver/server.py @@ -1,87 +1,145 @@ ### IMPORTS ### ============================================================================ -## Standard Library -import logging +## Future +from __future__ import annotations -# Note: Optional can only be replaced with `| None` in 3.10+ -from typing import List, Dict, Optional, Union, Type, Pattern +## Standard Library +from typing import TypeVar, Generic, Pattern ## Installed import dnslib +from pillar.logging import LoggingMixin ## Application -from .exceptions import InvalidMessageError from .models import Query, Response -from .rules import smart_make_rule, RuleBase, ResponseFunction -from .settings import Settings -from .transport import TransportBase, UDPv4Transport, UDPv6Transport, TCPv4Transport +from .rules import coerce_to_response, smart_make_rule, RuleBase, ResponseFunction -from . import middleware +from . import middleware as m ### CONSTANTS ### ============================================================================ -TRANSPORT_MAP: Dict[str, Type[TransportBase]] = { - "UDPv4": UDPv4Transport, - "UDPv6": UDPv6Transport, - "TCPv4": TCPv4Transport, -} +# pylint: disable=invalid-name +T_middleware = TypeVar("T_middleware", bound=m.MiddlewareBase) +T_exception_handler = TypeVar("T_exception_handler", bound=m.ExceptionHandlerBase) +# pylint: enable=invalid-name ### Classes ### ============================================================================ -class _LoggingMixin: # pylint: disable=too-few-public-methods - """Self bound logging methods""" +class MiddlewareMixin(Generic[T_middleware, T_exception_handler]): + """Generic mixin for building a middleware stack in a server. - _logger: logging.Logger + Should not be used directly, instead use the servers that implement it: + `NameServer`, `RawNameServer`. - def _vvdebug(self, *args, **kwargs): - """Log very verbose debug message.""" + New in `3.0`. + """ - return self._logger.log(6, *args, **kwargs) + _exception_handler: T_exception_handler - def _vdebug(self, *args, **kwargs): - """Log verbose debug message.""" + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._middleware_stack_final: list[T_middleware] | None = None + self._middleware_stack_user: list[T_middleware] = [] + return - return self._logger.log(8, *args, **kwargs) + ## Middleware + ## ------------------------------------------------------------------------- + def middleware_is_prepared(self) -> bool: + """Check if the middleware has been prepared.""" + return self._middleware_stack_final is not None - def _debug(self, *args, **kwargs): - """Log debug message.""" + def append_middleware(self, middleware: T_middleware) -> None: + """Append this middleware to the middleware stack - return self._logger.debug(*args, **kwargs) + Args: + middleware: middleware to append + """ + if self.middleware_is_prepared(): + raise RuntimeError("Cannot append middleware once prepared") + self._middleware_stack_user.append(middleware) + return - def _info(self, *args, **kwargs): - """Log very verbose debug message.""" + def prepare_middleware(self) -> None: + """Prepare middleware for consumption - return self._logger.info(*args, **kwargs) + Child classes should wrap this method to set the `next_function` on the + final middleware in the stack. + """ + if self.middleware_is_prepared(): + raise RuntimeError("Middleware is already prepared") - def _warning(self, *args, **kwargs): - """Log warning message.""" + middleware_stack = self._prepare_middleware_stack() - return self._logger.warning(*args, **kwargs) + next_middleware: T_middleware | None = None - def _error(self, *args, **kwargs): - """Log an error message.""" + for middleware in middleware_stack[::-1]: + if next_middleware is not None: + middleware.set_next_function(next_middleware) + next_middleware = middleware - return self._logger.error(*args, **kwargs) + self._middleware_stack_final = middleware_stack + return - def _critical(self, *args, **kwargs): - """Log a critical message.""" + def _prepare_middleware_stack(self) -> list[T_middleware]: + """Create final stack of middleware. - return self._logger.critical(*args, **kwargs) + Child classes may override this method to customise the final middleware stack. + """ + return [self._exception_handler, *self._middleware_stack_user] # type: ignore[list-item] + @property + def middleware(self) -> list[T_middleware]: + """Accssor for this servers middleware. -class RulesContainer(_LoggingMixin): - """Base class for rules based functionality` + If the server has been prepared then returns a copy of the prepared middleware. + Otherwise returns a mutable list of the registered middleware. + """ + if self.middleware_is_prepared(): + return self._middleware_stack_final.copy() # type: ignore[union-attr] + return self._middleware_stack_user + + ## Exception Handler + ## ------------------------------------------------------------------------- + def register_exception_handler(self, *args, **kwargs) -> None: + """Shortcut for `self.exception_handler.set_handler`""" + self.exception_handler_middleware.set_handler(*args, **kwargs) + return - New in `2.0`. + @property + def exception_handler_middleware(self) -> T_exception_handler: + """Read only accessor for this server's middleware exception handler""" + return self._exception_handler + + def exception_handler(self, exception_class: type[Exception]): + """Decorator for registering a function as an raw exception handler + + Args: + exception_class: The `Exception` class to register this handler for + """ + + def decorator(func): + nonlocal exception_class + self.register_raw_exception_handler(exception_class, func) + return func + + return decorator + + +## Mixins +## ----------------------------------------------------------------------------- +class RulesMixin(LoggingMixin): + """Base class for rules based functionality` Attributes: - rules: registered rules + rules: reistered rules + + New in `3.0`. """ def __init__(self) -> None: super().__init__() - self.rules: List[RuleBase] = [] + self.rules: list[RuleBase] = [] return def register_rule(self, rule: RuleBase) -> None: @@ -90,11 +148,11 @@ def register_rule(self, rule: RuleBase) -> None: Args: rule: the rule to register """ - self._debug(f"Registered rule: {rule!r}") + self.vdebug(f"Registered rule: {rule!r}") self.rules.append(rule) return - def rule(self, rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs): + def rule(self, rule_: type[RuleBase] | str | Pattern, *args, **kwargs): """Decorator for registering a function using [`smart_make_rule`][nserver.rules.smart_make_rule]. Changed in `2.0`: This method now uses `smart_make_rule`. @@ -121,39 +179,112 @@ def decorator(func: ResponseFunction): return decorator -class ServerBase(RulesContainer): - """Base class for shared functionality between `NameServer` and `SubServer` +## Servers +## ----------------------------------------------------------------------------- +class RawNameServer( + MiddlewareMixin[m.RawMiddleware, m.RawExceptionHandlerMiddleware], LoggingMixin +): + """Server that handles raw `dnslib.DNSRecord` queries. - New in `2.0`. + This allows interacting with the underlying DNS messages from our dns library. + As such this server is implementation dependent and may change from time to time. - Attributes: - hook_middleware: hook middleware - exception_handler_middleware: Query exception handler middleware + In general you should use `NameServer` as it is implementation independent. + + New in `3.0`. """ - def __init__(self) -> None: + def __init__(self, nameserver: NameServer) -> None: + self._exception_handler = m.RawExceptionHandlerMiddleware() + super().__init__() + self.nameserver: NameServer = nameserver + self.logger = self.get_logger() + return + + def process_request(self, request: m.RawRecord) -> m.RawRecord: + """Process a request using this server. + + This will pass the request through the middleware stack. + """ + if not self.middleware_is_prepared(): + self.prepare_middleware() + return self.middleware[0](request) + + def send_request_to_nameserver(self, record: m.RawRecord) -> m.RawRecord: + """Send a request to the `NameServer` of this instance. + + Although this is the final step after passing a request through all middleware, + it can be called directly to avoid using middleware such as when testing. + """ + response = record.reply() + + if record.header.opcode != dnslib.OPCODE.QUERY: + self.debug(f"Received non-query opcode: {record.header.opcode}") + # This server only response to DNS queries + response.header.rcode = dnslib.RCODE.NOTIMP + return response + + if len(record.questions) != 1: + self.debug(f"Received len(questions_ != 1 ({record.questions})") + # To simplify things we only respond if there is 1 question. + # This is apparently common amongst DNS server implementations. + # For more information see the responses to this SO question: + # https://stackoverflow.com/q/4082081 + response.header.rcode = dnslib.RCODE.REFUSED + return response + + try: + query = Query.from_dns_question(record.questions[0]) + except ValueError: + # TODO: should we embed raw DNS query? Maybe this should be configurable. + self.warning("Failed to parse Query from request", exc_info=True) + response.header.rcode = dnslib.RCODE.FORMERR + return response + + result = self.nameserver.process_request(query) + + response.add_answer(*result.get_answer_records()) + response.add_ar(*result.get_additional_records()) + response.add_auth(*result.get_authority_records()) + response.header.rcode = result.error_code + return response + + def prepare_middleware(self) -> None: + super().prepare_middleware() + self.middleware[-1].set_next_function(self.send_request_to_nameserver) + return + + +class NameServer( + MiddlewareMixin[m.QueryMiddleware, m.QueryExceptionHandlerMiddleware], RulesMixin, LoggingMixin +): + """High level DNS Name Server for responding to DNS queries.""" + + def __init__(self, name: str) -> None: """ Args: name: The name of the server. This is used for internal logging. """ + self.name = name + self._exception_handler = m.QueryExceptionHandlerMiddleware() super().__init__() - self.hook_middleware = middleware.HookMiddleware() - self.exception_handler_middleware = middleware.ExceptionHandlerMiddleware() - - self._user_query_middleware: List[middleware.QueryMiddleware] = [] - self._query_middleware_stack: List[ - Union[middleware.QueryMiddleware, middleware.QueryMiddlewareCallable] - ] = [] + self.hooks = m.HookMiddleware() + self.logger = self.get_logger() return + def _prepare_middleware_stack(self) -> list[m.QueryMiddleware]: + stack = super()._prepare_middleware_stack() + stack.append(self.hooks) + return stack + ## Register Methods ## ------------------------------------------------------------------------- def register_subserver( - self, subserver: "SubServer", rule_: Union[Type[RuleBase], str, Pattern], *args, **kwargs + self, nameserver: NameServer, rule_: type[RuleBase] | str | Pattern, *args, **kwargs ) -> None: - """Register a `SubServer` using [`smart_make_rule`][nserver.rules.smart_make_rule]. + """Register a `NameServer` using [`smart_make_rule`][nserver.rules.smart_make_rule]. - New in `2.0`. + This allows for composing larger applications. Args: subserver: the `SubServer` to attach @@ -163,23 +294,25 @@ def register_subserver( Raises: ValueError: if `func` is provided in `kwargs`. + + New in `3.0`. """ if "func" in kwargs: raise ValueError("Must not provide `func` in kwargs") - self.register_rule(smart_make_rule(rule_, *args, func=subserver.entrypoint, **kwargs)) + self.register_rule(smart_make_rule(rule_, *args, func=nameserver.process_request, **kwargs)) return - def register_before_first_query(self, func: middleware.BeforeFirstQueryHook) -> None: + def register_before_first_query(self, func: m.BeforeFirstQueryHook) -> None: """Register a function to be run before the first query. Args: func: the function to register """ - self.hook_middleware.before_first_query.append(func) + self.hooks.before_first_query.append(func) return - def register_before_query(self, func: middleware.BeforeQueryHook) -> None: + def register_before_query(self, func: m.BeforeQueryHook) -> None: """Register a function to be run before every query. Args: @@ -187,50 +320,16 @@ def register_before_query(self, func: middleware.BeforeQueryHook) -> None: If `func` returns anything other than `None` will stop processing the incoming `Query` and continue to result processing with the return value. """ - self.hook_middleware.before_query.append(func) + self.hooks.before_query.append(func) return - def register_after_query(self, func: middleware.AfterQueryHook) -> None: + def register_after_query(self, func: m.AfterQueryHook) -> None: """Register a function to be run on the result of a query. Args: func: the function to register """ - self.hook_middleware.after_query.append(func) - return - - def register_middleware(self, query_middleware: middleware.QueryMiddleware) -> None: - """Add a `QueryMiddleware` to this server. - - New in `2.0`. - - Args: - query_middleware: the middleware to add - """ - if self._query_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("Cannot register middleware after stack is created") - self._user_query_middleware.append(query_middleware) - return - - def register_exception_handler( - self, exception_class: Type[Exception], handler: middleware.ExceptionHandler - ) -> None: - """Register an exception handler for the `QueryMiddleware` - - Only one handler can exist for a given exception type. - - New in `2.0`. - - Args: - exception_class: the type of exception to handle - handler: the function to call when handling an exception - """ - if exception_class in self.exception_handler_middleware.exception_handlers: - raise ValueError("Exception handler already exists for {exception_class}") - - self.exception_handler_middleware.exception_handlers[exception_class] = handler + self.hooks.after_query.append(func) return # Decorators @@ -242,7 +341,7 @@ def before_first_query(self): before any further processesing. """ - def decorator(func: middleware.BeforeFirstQueryHook): + def decorator(func: m.BeforeFirstQueryHook): self.register_before_first_query(func) return func @@ -254,7 +353,7 @@ def before_query(self): These functions are called before processing each query. """ - def decorator(func: middleware.BeforeQueryHook): + def decorator(func: m.BeforeQueryHook): self.register_before_query(func) return func @@ -267,304 +366,47 @@ def after_query(self): response. """ - def decorator(func: middleware.AfterQueryHook): + def decorator(func: m.AfterQueryHook): self.register_after_query(func) return func return decorator - def exception_handler(self, exception_class: Type[Exception]): - """Decorator for registering a function as an exception handler - - New in `2.0`. - - Args: - exception_class: The `Exception` class to register this handler for - """ - - def decorator(func: middleware.ExceptionHandler): - nonlocal exception_class - self.register_exception_handler(exception_class, func) - return func - - return decorator - - ## Internal Functions - ## ------------------------------------------------------------------------- - def _prepare_query_middleware_stack(self) -> None: - """Prepare the `QueryMiddleware` for this server.""" - if self._query_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("QueryMiddleware stack already exists") - - middleware_stack: List[middleware.QueryMiddleware] = [ - self.exception_handler_middleware, - *self._user_query_middleware, - self.hook_middleware, - ] - rule_processor = middleware.RuleProcessor(self.rules) - - next_middleware: Optional[middleware.QueryMiddleware] = None - for query_middleware in middleware_stack[::-1]: - if next_middleware is None: - query_middleware.register_next_function(rule_processor) - else: - query_middleware.register_next_function(next_middleware) - next_middleware = query_middleware - - self._query_middleware_stack.extend(middleware_stack) - self._query_middleware_stack.append(rule_processor) - return - - -class NameServer(ServerBase): - """NameServer for responding to requests.""" - - # pylint: disable=too-many-instance-attributes - - def __init__(self, name: str, settings: Optional[Settings] = None) -> None: - """ - Args: - name: The name of the server. This is used for internal logging. - settings: settings to use with this `NameServer` instance - """ - super().__init__() - self.name = name - self._logger = logging.getLogger(f"nserver.i.nameserver.{self.name}") - - self.raw_exception_handler_middleware = middleware.RawRecordExceptionHandlerMiddleware() - self._user_raw_record_middleware: List[middleware.RawRecordMiddleware] = [] - self._raw_record_middleware_stack: List[ - Union[middleware.RawRecordMiddleware, middleware.RawRecordMiddlewareCallable] - ] = [] - - self.settings = settings if settings is not None else Settings() - - transport = TRANSPORT_MAP.get(self.settings.server_transport) - if transport is None: - raise ValueError( - f"Invalid settings.server_transport {self.settings.server_transport!r}" - ) - self.transport = transport(self.settings) - - self.shutdown_server = False - self.exit_code = 0 - return - - ## Register Methods - ## ------------------------------------------------------------------------- - def register_raw_middleware(self, raw_middleware: middleware.RawRecordMiddleware) -> None: - """Add a `RawRecordMiddleware` to this server. - - New in `2.0`. - - Args: - raw_middleware: the middleware to add - """ - if self._raw_record_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("Cannot register middleware after stack is created") - self._user_raw_record_middleware.append(raw_middleware) - return - - def register_raw_exception_handler( - self, exception_class: Type[Exception], handler: middleware.RawRecordExceptionHandler - ) -> None: - """Register a raw exception handler for the `RawRecordMiddleware`. - - Only one handler can exist for a given exception type. - - New in `2.0`. - - Args: - exception_class: the type of exception to handle - handler: the function to call when handling an exception - """ - if exception_class in self.raw_exception_handler_middleware.exception_handlers: - raise ValueError("Exception handler already exists for {exception_class}") - - self.raw_exception_handler_middleware.exception_handlers[exception_class] = handler - return - - # Decorators - # .......................................................................... - def raw_exception_handler(self, exception_class: Type[Exception]): - """Decorator for registering a function as an raw exception handler - - New in `2.0`. - - Args: - exception_class: The `Exception` class to register this handler for - """ - - def decorator(func: middleware.RawRecordExceptionHandler): - nonlocal exception_class - self.register_raw_exception_handler(exception_class, func) - return func - - return decorator - - ## Public Methods - ## ------------------------------------------------------------------------- - def run(self) -> int: - """Start running the server - - Returns: - `exit_code`, `0` if exited normally - """ - # Setup Logging - console_logger = logging.StreamHandler() - console_logger.setLevel(self.settings.console_log_level) - - console_formatter = logging.Formatter( - "[{asctime}][{levelname}][{name}] {message}", style="{" - ) - - console_logger.setFormatter(console_formatter) - - self._logger.addHandler(console_logger) - self._logger.setLevel(min(self.settings.console_log_level, self.settings.file_log_level)) - - # Start Server - # TODO: Do we want to recreate the transport instance or do we assume that - # transport.shutdown_server puts it back into a ready state? - # We could make this configurable? :thonking: - - self._info(f"Starting {self.transport}") - try: - self._prepare_middleware_stacks() - self.transport.start_server() - except Exception as e: # pylint: disable=broad-except - self._critical(e) - self.exit_code = 1 - return self.exit_code - - # Process Requests - error_count = 0 - while True: - if self.shutdown_server: - break - try: - message = self.transport.receive_message() - response = self._process_dns_record(message.message) - message.response = response - self.transport.send_message_response(message) - except InvalidMessageError as e: - self._warning(f"{e}") - except Exception as e: # pylint: disable=broad-except - self._error(f"Uncaught error occured. {e}", exc_info=True) - error_count += 1 - if error_count >= self.settings.max_errors: - self._critical(f"Max errors hit ({error_count})") - self.shutdown_server = True - self.exit_code = 1 - except KeyboardInterrupt: - self._info("KeyboardInterrupt received.") - self.shutdown_server = True - - # Stop Server - self._info("Shutting down server") - self.transport.stop_server() - - # Teardown Logging - self._logger.removeHandler(console_logger) - return self.exit_code - ## Internal Functions ## ------------------------------------------------------------------------- - def _process_dns_record(self, message: dnslib.DNSRecord) -> dnslib.DNSRecord: - """Process the given DNSRecord by sending it into the `RawRecordMiddleware` stack. - - Args: - message: the DNS query to process - - Returns: - the DNS response - """ - if self._raw_record_middleware_stack is None: - raise RuntimeError( - "RawRecordMiddleware stack does not exist. Have you called _prepare_middleware?" - ) - return self._raw_record_middleware_stack[0](message) - - def _prepare_middleware_stacks(self) -> None: - """Prepare all middleware for this server.""" - self._prepare_query_middleware_stack() - self._prepare_raw_record_middleware_stack() - return - - def _prepare_raw_record_middleware_stack(self) -> None: - """Prepare the `RawRecordMiddleware` for this server.""" - if not self._query_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("Must prepare QueryMiddleware stack first") - - if self._raw_record_middleware_stack: - # Note: we can use truthy expression as once processed there will always be at - # least one item in the stack - raise RuntimeError("RawRecordMiddleware stack already exists") - - middleware_stack: List[middleware.RawRecordMiddleware] = [ - self.raw_exception_handler_middleware, - *self._user_raw_record_middleware, - ] - - query_middleware_processor = middleware.QueryMiddlewareProcessor( - self._query_middleware_stack[0] - ) - - next_middleware: Optional[middleware.RawRecordMiddleware] = None - for raw_middleware in middleware_stack[::-1]: - if next_middleware is None: - raw_middleware.register_next_function(query_middleware_processor) - else: - raw_middleware.register_next_function(next_middleware) - next_middleware = raw_middleware - - self._raw_record_middleware_stack.extend(middleware_stack) - self._raw_record_middleware_stack.append(query_middleware_processor) + def process_request(self, query: Query) -> Response: + """Process a query passing it through all middleware.""" + if not self.middleware_is_prepared(): + self.prepare_middleware() + return self.middleware[0](query) + + def prepare_middleware(self) -> None: + super().prepare_middleware() + self.middleware[-1].set_next_function(self.send_query_to_rules) return + def send_query_to_rules(self, query: Query) -> Response: + """Send a query to be processed by the rules of this instance. -class SubServer(ServerBase): - """Class that can replicate many of the functions of a `NameServer`. - - They can be used to construct or extend applications. - - A `SubServer` maintains it's own `QueryMiddleware` stack and list of rules. - - New in `2.0`. - """ - - def __init__(self, name: str) -> None: + Although intended to be the final step after passing a query through all middleware, + this method can be used to bypass the middleware of this server such as for testing. """ - Args: - name: The name of the server. This is used for internal logging. - """ - super().__init__() - self.name = name - self._logger = logging.getLogger(f"nserver.i.subserver.{self.name}") - return - - def entrypoint(self, query: Query) -> Response: - """Entrypoint into this `SubServer`. + for rule in self.rules: + rule_func = rule.get_func(query) + if rule_func is not None: + self.debug(f"Matched Rule: {rule}") + return coerce_to_response(rule_func(query)) - This method should be passed to rules as the function to run. - """ - if not self._query_middleware_stack: - self._prepare_query_middleware_stack() - return self._query_middleware_stack[0](query) + self.debug("Did not match any rule") + return Response(error_code=dnslib.RCODE.NXDOMAIN) -class Blueprint(RulesContainer, RuleBase): +class Blueprint(RulesMixin, RuleBase, LoggingMixin): """A container for rules that can be registered onto a server It can be registered as normal rule: `server.register_rule(blueprint_rule)` - New in `2.0`. + New in `3.0`. """ def __init__(self, name: str) -> None: @@ -574,14 +416,14 @@ def __init__(self, name: str) -> None: """ super().__init__() self.name = name - self._logger = logging.getLogger(f"nserver.i.blueprint.{self.name}") + self.logger = self.get_logger() return - def get_func(self, query: Query) -> Optional[ResponseFunction]: + def get_func(self, query: Query) -> ResponseFunction | None: for rule in self.rules: func = rule.get_func(query) if func is not None: - self._debug(f"matched {rule}") + self.debug(f"matched {rule}") return func - self._debug("did not match any rule") + self.debug("did not match any rule") return None diff --git a/src/nserver/settings.py b/src/nserver/settings.py deleted file mode 100644 index bd439cf..0000000 --- a/src/nserver/settings.py +++ /dev/null @@ -1,35 +0,0 @@ -### IMPORTS -### ============================================================================ -## Standard Library -from dataclasses import dataclass -import logging - -## Installed - -## Application - - -### CLASSES -### ============================================================================ -@dataclass -class Settings: - """Dataclass for NameServer settings - - Attributes: - server_transport: What `Transport` to use. See `nserver.server.TRANSPORT_MAP` for options. - server_address: What address `server_transport` will bind to. - server_port: what port `server_port` will bind to. - """ - - server_transport: str = "UDPv4" - server_address: str = "localhost" - server_port: int = 9953 - console_log_level: int = logging.INFO - file_log_level: int = logging.INFO - max_errors: int = 5 - - # Not implemented, ideas for useful things - # debug: bool = False # Put server into "debug mode" (e.g. hot reload) - # health_check: bool = False # provde route for health check - # stats: bool = False # provide route for retrieving operational stats - # remote_admin: bool = False # allow remote shutdown restart etc? diff --git a/src/nserver/transport.py b/src/nserver/transport.py index f431a36..04f9c74 100644 --- a/src/nserver/transport.py +++ b/src/nserver/transport.py @@ -1,5 +1,8 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library from collections import deque from dataclasses import dataclass @@ -8,16 +11,15 @@ import socket import struct import time +from typing import Deque, NewType, Any, cast -# Note: Union can only be replaced with `X | Y` in 3.10+ -from typing import Tuple, Optional, Dict, List, Deque, NewType, Any, Union, cast ## Installed import dnslib +from pillar.logging import LoggingMixin ## Application from .exceptions import InvalidMessageError -from .settings import Settings ### CONSTANTS @@ -48,7 +50,7 @@ class TcpState(enum.IntEnum): ### FUNCTIONS ### ============================================================================ -def get_tcp_info(connection: socket.socket) -> Tuple: +def get_tcp_info(connection: socket.socket) -> tuple: """Get `socket.TCP_INFO` from socket Args: @@ -111,9 +113,9 @@ class MessageContainer: # pylint: disable=too-few-public-methods def __init__( self, raw_data: bytes, - transport: "TransportBase", + transport: TransportBase, transport_data: Any, - remote_client: Union[str, Tuple[str, int]], + remote_client: str | tuple[str, int], ): """Create new message container @@ -148,7 +150,7 @@ def __init__( self.transport = transport self.transport_data = transport_data self.remote_client = remote_client - self.response: Optional[dnslib.DNSRecord] = None + self.response: dnslib.DNSRecord | None = None return def get_response_bytes(self): @@ -160,16 +162,11 @@ def get_response_bytes(self): ## Transport Classes ## ----------------------------------------------------------------------------- -class TransportBase: +class TransportBase(LoggingMixin): """Base class for all transports""" - def __init__(self, settings: Settings) -> None: - """ - Args: - settings: settings of the server this transport is attached to - """ - self.settings = settings - # TODO: setup logging + def __init__(self) -> None: + self.logger = self.get_logger() return def start_server(self, timeout: int = 60) -> None: @@ -199,7 +196,7 @@ class UDPMessageData: remote_address: UDP peername that this message was received from """ - remote_address: Tuple[str, int] + remote_address: tuple[str, int] class UDPv4Transport(TransportBase): @@ -207,10 +204,10 @@ class UDPv4Transport(TransportBase): _SOCKET_AF = socket.AF_INET - def __init__(self, settings: Settings): - super().__init__(settings) - self.address = self.settings.server_address - self.port = self.settings.server_port + def __init__(self, address: str, port: int): + super().__init__() + self.address = address + self.port = port self.socket = socket.socket(self._SOCKET_AF, socket.SOCK_DGRAM) return @@ -284,7 +281,7 @@ class CachedConnection: """ connection: socket.socket - remote_address: Tuple[str, int] + remote_address: tuple[str, int] last_data_time: float selector_key: selectors.SelectorKey cache_key: CacheKey @@ -306,17 +303,17 @@ class TCPv4Transport(TransportBase): CONNECTION_CACHE_TARGET = int(CONNECTION_CACHE_LIMIT * CONNECTION_CACHE_VACUUM_PERCENT) CONNECTION_CACHE_CLEAN_INTERVAL = 10 # seconds - def __init__(self, settings: Settings) -> None: - super().__init__(settings) - self.address = self.settings.server_address - self.port = self.settings.server_port + def __init__(self, address: str, port: int) -> None: + super().__init__() + self.address = address + self.port = port self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.setblocking(False) # Allow taking over of socket when in TIME_WAIT (i.e. previously released) self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.selector = selectors.DefaultSelector() - self.cached_connections: Dict[CacheKey, CachedConnection] = {} + self.cached_connections: dict[CacheKey, CachedConnection] = {} self.last_cache_clean = 0.0 self.connection_queue: Deque[socket.socket] = deque() @@ -380,7 +377,7 @@ def stop_server(self) -> None: def __repr__(self): return f"{self.__class__.__name__}(address={self.address!r}, port={self.port!r})" - def _get_next_connection(self) -> Tuple[socket.socket, Tuple[str, int]]: + def _get_next_connection(self) -> tuple[socket.socket, tuple[str, int]]: """Get the next connection that is ready to receive data on.""" while not self.connection_queue: # loop until connection is ready for execution @@ -471,7 +468,7 @@ def _connection_viable(connection: socket.socket) -> bool: def _cleanup_cached_connections(self) -> None: "Cleanup cached connections" now = time.time() - cache_clear: List[CacheKey] = [] + cache_clear: list[CacheKey] = [] for cache_key, cache in self.cached_connections.items(): if now - cache.last_data_time > self.CONNECTION_KEEPALIVE_LIMIT: if cache.connection not in self.connection_queue: @@ -485,7 +482,7 @@ def _cleanup_cached_connections(self) -> None: for cache_key in cache_clear: self._remove_connection(cache_key=cache_key) - quiet_connections: List[CachedConnection] = [] + quiet_connections: list[CachedConnection] = [] cached_connections_len = len(self.cached_connections) cache_clear = [] @@ -516,7 +513,7 @@ def _cleanup_cached_connections(self) -> None: return def _remove_connection( - self, connection: Optional[socket.socket] = None, cache_key: Optional[CacheKey] = None + self, connection: socket.socket | None = None, cache_key: CacheKey | None = None ) -> None: """Remove a connection from the server (closing it in the process) diff --git a/src/nserver/util.py b/src/nserver/util.py index 261add3..191fad7 100644 --- a/src/nserver/util.py +++ b/src/nserver/util.py @@ -1,5 +1,8 @@ ### IMPORTS ### ============================================================================ +## Future +from __future__ import annotations + ## Standard Library ## Installed diff --git a/tests/test_blueprint.py b/tests/test_blueprint.py index ed9df03..30a6fa6 100644 --- a/tests/test_blueprint.py +++ b/tests/test_blueprint.py @@ -7,7 +7,7 @@ import dnslib import pytest -from nserver import NameServer, Blueprint, Query, A +from nserver import NameServer, RawNameServer, Blueprint, Query, A ## Application @@ -18,6 +18,7 @@ blueprint_1 = Blueprint("blueprint_1") blueprint_2 = Blueprint("blueprint_2") blueprint_3 = Blueprint("blueprint_3") +raw_server = RawNameServer(server) ## Rules @@ -36,8 +37,6 @@ def dummy_rule(query: Query) -> A: server.register_rule(blueprint_2) blueprint_2.register_rule(blueprint_3) -server._prepare_middleware_stacks() - ### TESTS ### ============================================================================ @@ -45,7 +44,7 @@ def dummy_rule(query: Query) -> A: ## ----------------------------------------------------------------------------- @pytest.mark.parametrize("question", ["s.com", "b1.com", "b2.com", "b3.b2.com"]) def test_response(question: str): - response = server._process_dns_record(dnslib.DNSRecord.question(question)) + response = raw_server.process_request(dnslib.DNSRecord.question(question)) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == question @@ -54,7 +53,7 @@ def test_response(question: str): @pytest.mark.parametrize("question", ["miss.s.com", "miss.b1.com", "miss.b2.com", "miss.b3.b2.com"]) def test_nxdomain(question: str): - response = server._process_dns_record(dnslib.DNSRecord.question(question)) + response = raw_server.process_request(dnslib.DNSRecord.question(question)) assert len(response.rr) == 0 assert response.header.rcode == dnslib.RCODE.NXDOMAIN return diff --git a/tests/test_server.py b/tests/test_server.py index f8561f4..9a12db0 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -10,7 +10,7 @@ import dnslib import pytest -from nserver import NameServer, Query, Response, A +from nserver import NameServer, RawNameServer, Query, Response, A ## Application @@ -18,6 +18,7 @@ ### ============================================================================ IP = "127.0.0.1" server = NameServer("tests") +raw_server = RawNameServer(server) ## Rules @@ -106,11 +107,10 @@ def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> raw_record_error_handler = unittest.mock.MagicMock(wraps=_raw_record_error_handler) -server.register_raw_exception_handler(ErrorForTesting, raw_record_error_handler) +raw_server.register_exception_handler(ErrorForTesting, raw_record_error_handler) ## Get server ready ## ----------------------------------------------------------------------------- -server._prepare_middleware_stacks() ### TESTS @@ -118,13 +118,13 @@ def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> ## NameServer._process_dns_record ## ----------------------------------------------------------------------------- def test_none_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("none-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("none-response.com")) assert len(response.rr) == 0 return def test_response_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("response-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("response-response.com")) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == "response-response.com." @@ -132,7 +132,7 @@ def test_response_response(): def test_record_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("record-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("record-response.com")) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == "record-response.com." @@ -140,7 +140,7 @@ def test_record_response(): def test_multi_record_response(): - response = server._process_dns_record(dnslib.DNSRecord.question("multi-record-response.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("multi-record-response.com")) assert len(response.rr) == 2 for record in response.rr: assert record.rtype == 1 @@ -160,12 +160,12 @@ def test_multi_record_response(): ) def test_hook_call_count(hook, call_count): # Setup - server.hook_middleware.before_first_query_run = False + server.hooks.before_first_query_run = False hook.reset_mock() # Test for _ in range(5): - response = server._process_dns_record(dnslib.DNSRecord.question("dummy.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("dummy.com")) # Ensure respone returns and unchanged assert len(response.rr) == 1 assert response.rr[0].rtype == 1 @@ -183,7 +183,7 @@ def test_query_error_handler(): raw_record_error_handler.reset_mock() # Test - response = server._process_dns_record(dnslib.DNSRecord.question("throw-error.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("throw-error.com")) assert len(response.rr) == 0 assert response.header.get_rcode() == dnslib.RCODE.SERVFAIL @@ -199,7 +199,7 @@ def test_raw_record_error_handler(): raw_record_error_handler.reset_mock() # Test - response = server._process_dns_record(dnslib.DNSRecord.question("throw-another-error.com")) + response = raw_server.process_request(dnslib.DNSRecord.question("throw-another-error.com")) assert len(response.rr) == 0 assert response.header.get_rcode() == dnslib.RCODE.SERVFAIL diff --git a/tests/test_subserver.py b/tests/test_subserver.py index 601e51e..ee5c9b9 100644 --- a/tests/test_subserver.py +++ b/tests/test_subserver.py @@ -10,8 +10,7 @@ import dnslib import pytest -from nserver import NameServer, SubServer, Query, Response, ALL_QTYPES, ZoneRule, A -from nserver.server import ServerBase +from nserver import NameServer, RawNameServer, Query, Response, ALL_QTYPES, ZoneRule, A ## Application @@ -19,9 +18,10 @@ ### ============================================================================ IP = "127.0.0.1" nameserver = NameServer("test_subserver") -subserver_1 = SubServer("subserver_1") -subserver_2 = SubServer("subserver_2") -subserver_3 = SubServer("subserver_3") +subserver_1 = NameServer("subserver_1") +subserver_2 = NameServer("subserver_2") +subserver_3 = NameServer("subserver_3") +raw_nameserver = RawNameServer(nameserver) ## Rules @@ -36,7 +36,7 @@ def dummy_rule(query: Query) -> A: ## Hooks ## ----------------------------------------------------------------------------- -def register_hooks(server: ServerBase) -> None: +def register_hooks(server: NameServer) -> None: server.register_before_first_query(unittest.mock.MagicMock(wraps=lambda: None)) server.register_before_query(unittest.mock.MagicMock(wraps=lambda q: None)) server.register_after_query(unittest.mock.MagicMock(wraps=lambda r: r)) @@ -44,11 +44,11 @@ def register_hooks(server: ServerBase) -> None: @no_type_check -def reset_hooks(server: ServerBase) -> None: - server.hook_middleware.before_first_query_run = False - server.hook_middleware.before_first_query[0].reset_mock() - server.hook_middleware.before_query[0].reset_mock() - server.hook_middleware.after_query[0].reset_mock() +def reset_hooks(server: NameServer) -> None: + server.hooks.before_first_query_run = False + server.hooks.before_first_query[0].reset_mock() + server.hooks.before_query[0].reset_mock() + server.hooks.after_query[0].reset_mock() return @@ -61,10 +61,10 @@ def reset_all_hooks() -> None: @no_type_check -def check_hook_call_count(server: ServerBase, bfq_count: int, bq_count: int, aq_count: int) -> None: - assert server.hook_middleware.before_first_query[0].call_count == bfq_count - assert server.hook_middleware.before_query[0].call_count == bq_count - assert server.hook_middleware.after_query[0].call_count == aq_count +def check_hook_call_count(server: NameServer, bfq_count: int, bq_count: int, aq_count: int) -> None: + assert server.hooks.before_first_query[0].call_count == bfq_count + assert server.hooks.before_query[0].call_count == bq_count + assert server.hooks.after_query[0].call_count == aq_count return @@ -111,24 +111,12 @@ def bad_error_handler(query: Query, exception: Exception) -> Response: nameserver.register_exception_handler(ThrowAnotherError, bad_error_handler) -def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> dnslib.DNSRecord: - # pylint: disable=unused-argument - response = record.reply() - response.header.rcode = dnslib.RCODE.SERVFAIL - return response - - -raw_record_error_handler = unittest.mock.MagicMock(wraps=_raw_record_error_handler) -nameserver.register_raw_exception_handler(ErrorForTesting, raw_record_error_handler) - ## Get server ready ## ----------------------------------------------------------------------------- nameserver.register_subserver(subserver_1, ZoneRule, "sub1.com", ALL_QTYPES) nameserver.register_subserver(subserver_2, ZoneRule, "sub2.com", ALL_QTYPES) subserver_2.register_subserver(subserver_3, ZoneRule, "sub3.sub2.com", ALL_QTYPES) -nameserver._prepare_middleware_stacks() - ### TESTS ### ============================================================================ @@ -136,7 +124,7 @@ def _raw_record_error_handler(record: dnslib.DNSRecord, exception: Exception) -> ## ----------------------------------------------------------------------------- @pytest.mark.parametrize("question", ["s.com", "sub1.com", "sub2.com", "sub3.sub2.com"]) def test_response(question: str): - response = nameserver._process_dns_record(dnslib.DNSRecord.question(question)) + response = raw_nameserver.process_request(dnslib.DNSRecord.question(question)) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == question @@ -147,7 +135,7 @@ def test_response(question: str): "question", ["miss.s.com", "miss.sub1.com", "miss.sub2.com", "miss.sub3.sub2.com"] ) def test_nxdomain(question: str): - response = nameserver._process_dns_record(dnslib.DNSRecord.question(question)) + response = raw_nameserver.process_request(dnslib.DNSRecord.question(question)) assert len(response.rr) == 0 assert response.header.rcode == dnslib.RCODE.NXDOMAIN return @@ -174,7 +162,7 @@ def test_hooks(question: str, hook_counts: List[int]): ## Test for _ in range(5): - response = nameserver._process_dns_record(dnslib.DNSRecord.question(question)) + response = raw_nameserver.process_request(dnslib.DNSRecord.question(question)) assert len(response.rr) == 1 assert response.rr[0].rtype == 1 assert response.rr[0].rname == question From bad0709ce90ec68c2f4b82ac3072d0fe2f182e7a Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 10 Nov 2024 15:03:03 +1100 Subject: [PATCH 12/16] Use typing_extensions --- pyproject.toml | 1 + src/nserver/middleware.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 62c1fdf..c333cc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "dnslib", "pillar~=0.3", "tldextract", + "typing-extensions;python_version<'3.10'", ] # Extra information diff --git a/src/nserver/middleware.py b/src/nserver/middleware.py index 66d7cab..575d509 100644 --- a/src/nserver/middleware.py +++ b/src/nserver/middleware.py @@ -7,6 +7,12 @@ import inspect import threading from typing import TYPE_CHECKING, Callable, Generic, TypeVar, TypeAlias +import sys + +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias ## Installed import dnslib From e60e98b07d3eb43c9cb6432be4ad4ad6e088f72e Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 10 Nov 2024 15:07:00 +1100 Subject: [PATCH 13/16] remove repeat import --- src/nserver/middleware.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nserver/middleware.py b/src/nserver/middleware.py index 575d509..25682c2 100644 --- a/src/nserver/middleware.py +++ b/src/nserver/middleware.py @@ -6,7 +6,7 @@ ## Standard Library import inspect import threading -from typing import TYPE_CHECKING, Callable, Generic, TypeVar, TypeAlias +from typing import TYPE_CHECKING, Callable, Generic, TypeVar import sys if sys.version_info < (3, 10): From 70e25fec428bc520a97e0c48a55b5db29ea4f62e Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 10 Nov 2024 15:12:04 +1100 Subject: [PATCH 14/16] Change import order --- src/nserver/middleware.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/nserver/middleware.py b/src/nserver/middleware.py index 25682c2..e27df48 100644 --- a/src/nserver/middleware.py +++ b/src/nserver/middleware.py @@ -9,11 +9,6 @@ from typing import TYPE_CHECKING, Callable, Generic, TypeVar import sys -if sys.version_info < (3, 10): - from typing_extensions import TypeAlias -else: - from typing import TypeAlias - ## Installed import dnslib from pillar.logging import LoggingMixin @@ -22,6 +17,13 @@ from .models import Query, Response from .rules import coerce_to_response, RuleResult +## Special +if sys.version_info < (3, 10): + from typing_extensions import TypeAlias +else: + from typing import TypeAlias + + ### CONSTANTS ### ============================================================================ # pylint: disable=invalid-name From 7eeac9bcfd70d81e6d8d98296061b1d646baf6b1 Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Sun, 10 Nov 2024 16:08:14 +1100 Subject: [PATCH 15/16] Add CLI entrypoint --- pyproject.toml | 3 +++ src/nserver/__init__.py | 2 ++ src/nserver/__main__.py | 18 ++++++++++++++++++ src/nserver/cli.py | 18 ++++++++++-------- 4 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 src/nserver/__main__.py diff --git a/pyproject.toml b/pyproject.toml index c333cc3..6dd432d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,9 @@ dev = [ "mike", ] +[project.scripts] +nserver = "nserver.__main__:main" + [tool.setuptools.package-data] nserver = ["py.typed"] diff --git a/src/nserver/__init__.py b/src/nserver/__init__.py index 0758315..dde0054 100644 --- a/src/nserver/__init__.py +++ b/src/nserver/__init__.py @@ -1,3 +1,5 @@ +### IMPORTS +### ============================================================================ from .models import Query, Response from .rules import ALL_QTYPES, StaticRule, ZoneRule, RegexRule, WildcardStringRule from .records import A, AAAA, NS, CNAME, PTR, SOA, MX, TXT, CAA diff --git a/src/nserver/__main__.py b/src/nserver/__main__.py new file mode 100644 index 0000000..4039844 --- /dev/null +++ b/src/nserver/__main__.py @@ -0,0 +1,18 @@ +### IMPORTS +### ============================================================================ +from .cli import CliApplication + + +### FUNCTIONS +### ============================================================================ +def main(): + "CLI Entrypoint" + app = CliApplication() + app.run() + return app + + +### MAIN +### ============================================================================ +if __name__ == "__main__": + main() diff --git a/src/nserver/cli.py b/src/nserver/cli.py index 96be94b..88da49b 100644 --- a/src/nserver/cli.py +++ b/src/nserver/cli.py @@ -6,6 +6,8 @@ ## Standard Library import argparse import importlib +import os +import pydoc ## Installed import pillar.application @@ -38,6 +40,7 @@ def get_argument_parser(self) -> argparse.ArgumentParser: parser.add_argument( "--server", action="store", + required=True, help=( "Import path of server / factory to run in the form of " "package.module.path:attribute" @@ -100,7 +103,13 @@ def main(self) -> int | None: def get_server(self) -> NameServer | RawNameServer: """Factory for getting the server to run based on current settings""" module_path, attribute_path = self.args.server.split(":") - obj: object = importlib.import_module(module_path) + + obj: object + if os.path.isfile(module_path): + # Ref: https://stackoverflow.com/a/68361215/12281814 + obj = pydoc.importfile(module_path) + else: + obj = importlib.import_module(module_path) for attribute_name in attribute_path.split("."): obj = getattr(obj, attribute_name) @@ -123,10 +132,3 @@ def get_application(self) -> BaseApplication: self.args.transport(self.args.host, self.args.port), ) return application - - -### MAIN -### ============================================================================ -if __name__ == "__main__": - app = CliApplication() - app.run() From 0895a761d2c327287db8627643877ad2c27cd8df Mon Sep 17 00:00:00 2001 From: Nicholas Hairs Date: Wed, 13 Nov 2024 18:19:27 +1100 Subject: [PATCH 16/16] Update docs --- docs/error-handling.md | 9 +++------ docs/index.md | 5 ++++- docs/middleware.md | 37 ++++++++++++++++++------------------- docs/quickstart.md | 20 +++++++------------- docs/subserver-blueprint.md | 30 +++++++++++++----------------- 5 files changed, 45 insertions(+), 56 deletions(-) diff --git a/docs/error-handling.md b/docs/error-handling.md index 86243dc..2299e4c 100644 --- a/docs/error-handling.md +++ b/docs/error-handling.md @@ -1,11 +1,8 @@ # Error Handling -Custom exception handling is handled through the [`ExceptionHandlerMiddleware`][nserver.middleware.ExceptionHandlerMiddleware] and [`RawRecordExceptionHandlerMiddleware`][nserver.middleware.RawRecordExceptionHandlerMiddleware] [Middleware][middleware]. These middleware will catch any `Exception`s raised by their respective middleware stacks. +Custom exception handling is handled through the [`ExceptionHandlerMiddleware`][nserver.middleware.ExceptionHandlerMiddleware] and [`RawExceptionHandlerMiddleware`][nserver.middleware.RawExceptionHandlerMiddleware] [Middleware][middleware]. These middleware will catch any `Exception`s raised by their respective middleware stacks. -!!! note - Error handling requires `nserver>=2.0` - -In general you are probably able to use the `ExceptionHandlerMiddleware` as the `RawRecordExceptionHandlerMiddleware` is only needed to catch exceptions resulting from `RawRecordMiddleware` or broken exception handlers in the `ExceptionHandlerMiddleware`. If you only write `QueryMiddleware` and your `ExceptionHandlerMiddleware` handlers never raise exceptions then you'll be good to go with just the `ExceptionHandlerMiddleware`. +In general you are probably able to use the `ExceptionHandlerMiddleware` as the `RawExceptionHandlerMiddleware` is only needed to catch exceptions resulting from `RawMiddleware` or broken exception handlers in the `ExceptionHandlerMiddleware`. If you only write `QueryMiddleware` and your `ExceptionHandlerMiddleware` handlers never raise exceptions then you'll be good to go with just the `ExceptionHandlerMiddleware`. Both of these middleware have a default exception handler that will be used for anything not matching a registered handler. The default handler can be overwritten by registering a handler for the `Exception` class. @@ -16,7 +13,7 @@ Handlers are chosen by finding a handler for the most specific parent class of t ## Registering Exception Handlers -Exception handlers can be registered to `NameServer` and `SubServer` instances using either their `@[raw_]exception_handler` decorators or their `register_[raw_]exception_handler` methods. +Exception handlers can be registered to `NameServer` and `RawNameSeerver` instances using either their `@exception_handler` decorators or their `register_exception_handler` methods. ```python import dnslib diff --git a/docs/index.md b/docs/index.md index 50c5ab4..bc00750 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,7 +30,7 @@ NServer has been inspired by easy to use high level frameworks such as [Flask](h Follow our [Quickstart Guide](quickstart.md). -```python title="TLDR" +```python title="tldr.py" from nserver import NameServer, Query, A server = NameServer("example") @@ -43,6 +43,9 @@ if __name__ == "__main__": server.run() ``` +```bash +nserver --server tldr.py:server +``` ## Bugs, Feature Requests etc Please [submit an issue on github](https://github.com/nhairs/nserver/issues). diff --git a/docs/middleware.md b/docs/middleware.md index 41987ff..9e924da 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -3,11 +3,11 @@ Middleware can be used to modify the behaviour of a server seperate to the individual rules that are registered to the server. Middleware is run on all requests and can modify both the input and response of a request. !!! note - Middleware requires `nserver>=2.0` + Middleware requires `nserver>=3.0` ## Middleware Stacks -Middleware operates in a stack with each middleware calling the middleware below it until one returns and the result is propagated back up the chain. NServer uses two stacks, the outmost stack deals with raw DNS records (`RawRecordMiddleware`), which will eventually convert the record to a `Query` which will then be passed to the main `QueryMiddleware` stack. +Middleware operates in a stack with each middleware calling the middleware below it until one returns and the result is propagated back up the chain. NServer uses two stacks, the outmost stack deals with raw DNS records (`RawMiddleware`), which will eventually convert the record to a `Query` which will then be passed to the main `QueryMiddleware` stack. Middleware can be added to the application until it is run. Once the server begins running the middleware cannot be modified. The ordering of middleware is kept in the order in which it is added to the server; that is the first middleware registered will be called before the second and so on. @@ -65,33 +65,34 @@ Once processed the `QueryMiddleware` stack will look as follows: - `` - [`HookMiddleware`][nserver.middleware.HookMiddleware] - Runs hooks registered to the server. This can be considered a simplified version of middleware. -- [`RuleProcessor`][nserver.middleware.RuleProcessor] - - The entry point into our rule processing. -## `RawRecordMiddleware` +## `RawMiddleware` -[`RawRecordMiddleware`][nserver.middleware.RawRecordMiddleware] allows for modifying the raw `dnslib.DNSRecord`s that are recevied and sent by the server. +[`RawMiddleware`][nserver.middleware.RawMiddleware] allows for modifying the raw `dnslib.DNSRecord`s that are recevied and sent by the server. -### Registering `RawRecordMiddleware` +### Registering `RawMiddleware` -`RawRecordMiddleware` can be registered to `NameServer` instances using their `register_raw_middleware` method. +`RawMiddleware` can be registered to `RawNameServer` instances using their `register_middleware` method. ```python # ... -from nserver.middleware import RawRecordMiddleware +from nserver import RawNameServer +from nserver.middleware import RawMiddleware -server.register_raw_middleware(RawRecordMiddleware()) +raw_server = RawNameServer(server) + +server.register_middleware(RawMiddleware()) ``` -### Creating your own `RawRecordMiddleware` +### Creating your own `RawMiddleware` -Using an unmodified `RawRecordMiddleware` isn't very interesting as it just passes the request onto the next middleware. To add your own middleware you should subclass `RawRecordMiddleware` and override the `process_record` method. +Using an unmodified `RawMiddleware` isn't very interesting as it just passes the request onto the next middleware. To add your own middleware you should subclass `RawMiddleware` and override the `process_record` method. ```python # ... -class SizeLimiterMiddleware(RawRecordMiddleware): +class SizeLimiterMiddleware(RawMiddleware): def __init__(self, max_size: int): super().__init__() self.max_size = max_size @@ -114,15 +115,13 @@ class SizeLimiterMiddleware(RawRecordMiddleware): return response -server.register_raw_middleware(SizeLimiterMiddleware(1400)) +server.register_middleware(SizeLimiterMiddleware(1400)) ``` -### Default `RawRecordMiddleware` stack +### Default `RawMiddleware` stack -Once processed the `RawRecordMiddleware` stack will look as follows: +Once processed the `RawMiddleware` stack will look as follows: -- [`RawRecordExceptionHandlerMiddleware`][nserver.middleware.RawRecordExceptionHandlerMiddleware] +- [`RawExceptionHandlerMiddleware`][nserver.middleware.RawExceptionHandlerMiddleware] - Customisable error handler for `Exception`s originating from within the stack. - `` -- [`QueryMiddlewareProcessor`][nserver.middleware.QueryMiddlewareProcessor] - - entry point into the `QueryMiddleware` stack. diff --git a/docs/quickstart.md b/docs/quickstart.md index 25d25fc..fa78f54 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -19,9 +19,6 @@ server = NameServer("example") @server.rule("example.com", ["A"]) def example_a_records(query: Query): return A(query.name, "1.2.3.4") - -if __name__ == "__main__": - server.run() ``` Here's what this code does: @@ -37,28 +34,25 @@ Here's what this code does: 4. When triggered our function will then return a single `A` record as a response. -5. Finally we add code so that we can run our server. - ### Running our server -With our server written we can now run it: +With our server written we can now run it using the `nserver` CLI: -```shell -python3 example_server.py +```bash +nserver --server path/to/minimal_server.py ``` - ```{.none .no-copy} -[INFO] Starting UDPv4Transport(address='localhost', port=9953) +[INFO] Starting UDPv4Transport(address='localhost', port=5300) ``` We can access it using `dig`. ```shell -dig -p 9953 @localhost A example.com +dig -p 5300 @localhost A example.com ``` ```{.none .no-copy} -; <<>> DiG 9.18.12-0ubuntu0.22.04.3-Ubuntu <<>> -p 9953 @localhost A example.com +; <<>> DiG 9.18.12-0ubuntu0.22.04.3-Ubuntu <<>> -p 5300 @localhost A example.com ; (1 server found) ;; global options: +cmd ;; Got answer: @@ -72,7 +66,7 @@ dig -p 9953 @localhost A example.com example.com. 300 IN A 1.2.3.4 ;; Query time: 324 msec -;; SERVER: 127.0.0.1#9953(localhost) (UDP) +;; SERVER: 127.0.0.1#5300(localhost) (UDP) ;; WHEN: Thu Nov 02 21:27:12 AEDT 2023 ;; MSG SIZE rcvd: 45 ``` diff --git a/docs/subserver-blueprint.md b/docs/subserver-blueprint.md index c880566..774fc45 100644 --- a/docs/subserver-blueprint.md +++ b/docs/subserver-blueprint.md @@ -3,7 +3,7 @@ ## Sub-Servers -[`SubServer`][nserver.server.SubServer] provides a way for you to compose your application. They support most of the same functionality as a `NameServer`. +To allow for composing an application into different parts, a [`NameServer`][nserver.server.NameServer] can be included in another `NameServer`. Use cases: @@ -12,15 +12,15 @@ Use cases: - Allow custom packages to define their own rules that you can add to your own server. !!! note - SubServers requires `nserver>=2.0` + Adding a `NameServer` to another requires `nserver>=3.0` ### Using Sub-Servers ```python -from nserver import SubServer, NameServer, ZoneRule, ALL_CTYPES, A, TXT +from nserver import NameServer, ZoneRule, ALL_CTYPES, A, TXT -# First SubServer -mysite = SubServer("mysite") +# First child NameServer +mysite = NameServer("mysite") @mysite.rule("nicholashairs.com", ["A"]) @mysite.rule("www.nicholashairs.com", ["A"]) @@ -32,14 +32,14 @@ def nicholashairs_catchall(query: Query) -> None: # Return empty response for all other queries return None -# Second SubServer -en_subserver = SubServer("english-speaking-blueprint") +# Second child NameServer +en_subserver = NameServer("english-speaking-blueprint") @en_subserver.rule("hello.{base_domain}", ["TXT"]) def en_hello(query: Query) -> TXT: return TXT(query.name, "Hello There!") -# Register to NameServer +# Register to main NameServer server = NameServer("server") server.register_subserver(mysite, ZoneRule, "nicholashairs.com", ALL_CTYPES) server.register_subserver(en_subserver, ZoneRule, "au", ALL_CTYPES) @@ -47,22 +47,18 @@ server.register_subserver(en_subserver, ZoneRule, "nz", ALL_CTYPES) server.register_subserver(en_subserver, ZoneRule, "uk", ALL_CTYPES) ``` -#### Middleware, Hooks, and Error Handling +#### Middleware, Hooks, and Exception Handling -Sub-Servers maintain their own `QueryMiddleware` stack which will run before any rule function is run. Included in this stack is the `HookMiddleware` and `ExceptionHandlerMiddleware`. +Don't forget that each `NameServer` maintains it's own middleware stack, exception handlers, and hooks. -### Key differences with `NameServer` - -- Does not use settings (`Setting`). -- Does not have a `Transport`. -- Does not have a `RawRecordMiddleware` stack. +In particular errors will not propagate up from a child server to it's parent as the child's exception handler will catch any exception and return a response. ## Blueprints [`Blueprint`][nserver.server.Blueprint]s act as a container for rules. They are an efficient way to compose your application if you do not want or need to use functionality provided by a `QueryMiddleware` stack. !!! note - Blueprints require `nserver>=2.0` + Blueprints require `nserver>=3.0` ### Using Blueprints @@ -88,7 +84,7 @@ en_subserver.register_rule(no_email_blueprint) mysite.rules.insert(0, no_email_blueprint) ``` -### Key differences with `NameServer` and `SubServer` +### Key differences with `NameServer` - Only provides the `@rule` decorator and `register_rule` method. - It does not have a `QueryMiddleware` stack which means it does not support hooks or error-handling.