diff --git a/CveXplore/VERSION b/CveXplore/VERSION index 9d9b8c3b..a0a1ce09 100644 --- a/CveXplore/VERSION +++ b/CveXplore/VERSION @@ -1 +1 @@ -0.3.20.dev7 \ No newline at end of file +0.3.20.dev11 \ No newline at end of file diff --git a/CveXplore/common/config.py b/CveXplore/common/config.py index 6bd05566..c53bb3b8 100644 --- a/CveXplore/common/config.py +++ b/CveXplore/common/config.py @@ -86,6 +86,9 @@ class Configuration(object): CPE_FILTER_DEPRECATED = getenv_bool("CPE_FILTER_DEPRECATED", "True") + # Which datasource to query.Currently supported options include: + # - mongodb + # - api DATASOURCE = os.getenv("DATASOURCE", "mongodb") DATASOURCE_PROTOCOL = os.getenv("DATASOURCE_PROTOCOL", "mongodb") diff --git a/CveXplore/core/database_indexer/db_indexer.py b/CveXplore/core/database_indexer/db_indexer.py index 712e563f..28e26883 100644 --- a/CveXplore/core/database_indexer/db_indexer.py +++ b/CveXplore/core/database_indexer/db_indexer.py @@ -18,7 +18,7 @@ def __init__(self, datasource): super().__init__(__name__) database = datasource - self.database = database._dbclient + self.database = database.dbclient self.indexes = { "cpe": [ diff --git a/CveXplore/core/database_maintenance/download_handler.py b/CveXplore/core/database_maintenance/download_handler.py index a292a15a..b29fb338 100644 --- a/CveXplore/core/database_maintenance/download_handler.py +++ b/CveXplore/core/database_maintenance/download_handler.py @@ -47,6 +47,8 @@ class DownloadHandler(ABC): """ def __init__(self, feed_type: str, logger_name: str, prefix: str = None): + self.config = Configuration() + self._end = None self.feed_type = feed_type @@ -65,16 +67,14 @@ def __init__(self, feed_type: str, logger_name: str, prefix: str = None): self.do_process = True database = DatabaseConnection( - database_type=os.getenv("DATASOURCE_TYPE"), + database_type=self.config.DATASOURCE, database_init_parameters=json.loads(os.getenv("DATASOURCE_CON_DETAILS")), ).database_connection - self.database = database._dbclient + self.database = database.dbclient self.database_indexer = DatabaseIndexer(datasource=database) - self.config = Configuration() - self.logger = logging.getLogger(logger_name) self.logger.removeHandler(self.logger.handlers[0]) diff --git a/CveXplore/core/database_maintenance/main_updater.py b/CveXplore/core/database_maintenance/main_updater.py index 5253e086..6497c20b 100644 --- a/CveXplore/core/database_maintenance/main_updater.py +++ b/CveXplore/core/database_maintenance/main_updater.py @@ -16,7 +16,7 @@ EPSSDownloads, ) from CveXplore.core.database_maintenance.update_base_class import UpdateBaseClass -from CveXplore.core.database_schema.db_schema_checker import SchemaChecker +from CveXplore.core.database_version.db_version_checker import DatabaseVersionChecker from CveXplore.core.logging.logger_class import AppLogger from CveXplore.errors import UpdateSourceNotFound @@ -46,7 +46,7 @@ def __init__(self, datasource): ] self.database_indexer = DatabaseIndexer(datasource=datasource) - self.schema_checker = SchemaChecker(datasource=datasource) + self.schema_checker = DatabaseVersionChecker(datasource=datasource) def validate_schema(self): return self.schema_checker.validate_schema() diff --git a/CveXplore/core/database_schema/__init__.py b/CveXplore/core/database_migration/__init__.py similarity index 100% rename from CveXplore/core/database_schema/__init__.py rename to CveXplore/core/database_migration/__init__.py diff --git a/CveXplore/core/database_migration/database_migrator.py b/CveXplore/core/database_migration/database_migrator.py new file mode 100644 index 00000000..10439e7d --- /dev/null +++ b/CveXplore/core/database_migration/database_migrator.py @@ -0,0 +1,169 @@ +import argparse +import logging +import os +import sys +from collections import namedtuple +from subprocess import run, PIPE, STDOUT, CompletedProcess + +from CveXplore.core.logging.logger_class import AppLogger + +logging.setLoggerClass(AppLogger) + + +class DatabaseMigrator(object): + def __init__(self, cwd: str = None): + self.logger = logging.getLogger(__name__) + + self.current_dir = ( + cwd + if cwd is not None + else os.path.dirname( + os.path.dirname( + os.path.dirname(os.path.dirname(os.path.realpath(__file__))) + ) + ) + ) + + self._commands = namedtuple( + "commands", "INIT REVISION UPGRADE CURRENT HISTORY REV_UP REV_DOWN" + )(1, 2, 3, 4, 5, 6, 7) + + @property + def commands(self) -> namedtuple: + return self._commands + + def db_init(self) -> None: + res = self.__cli_runner(self.commands.INIT) + self.__parse_command_output(res) + + def db_revision(self, message: str) -> None: + res = self.__cli_runner(self.commands.REVISION, message=message) + self.__parse_command_output(res) + + def db_upgrade(self) -> None: + res = self.__cli_runner(self.commands.UPGRADE) + self.__parse_command_output(res) + + def db_current(self) -> None: + res = self.__cli_runner(self.commands.CURRENT) + self.__parse_command_output(res) + + def db_history(self) -> None: + res = self.__cli_runner(self.commands.HISTORY) + self.__parse_command_output(res) + + def db_up(self, count: int) -> None: + res = self.__cli_runner(self.commands.REV_UP, message=count) + self.__parse_command_output(res) + + def db_down(self, count: int) -> None: + res = self.__cli_runner(self.commands.REV_DOWN, message=count) + self.__parse_command_output(res) + + def __parse_command_output(self, cmd_output: CompletedProcess) -> None: + if cmd_output.returncode != 0: + self.logger.error(cmd_output.stdout.split("\n")[0]) + else: + output_list = cmd_output.stdout.split("\n") + + for m in output_list: + if m != "": + self.logger.info(m) + + def __cli_runner(self, command: int, message: str | int = None) -> CompletedProcess: + if command == 2 and message is None: + raise ValueError("Missing message for revision command") + elif command == 6 and message is None: + raise ValueError( + "You must specify a positive number when submitting a upgrade command" + ) + elif command == 7 and message is None: + raise ValueError( + "You must specify a negative number when submitting a downgrade command" + ) + + command_mapping = { + 1: f"alembic init alembic", + 2: f"alembic revision -m {message}", + 3: f"alembic upgrade head", + 4: f"alembic current", + 5: f"alembic history --verbose", + 6: f"alembic upgrade {message}", + 7: f"alembic downgrade {message}", + } + try: + result = run( + command_mapping[command], # nosec + stdout=PIPE, + stderr=STDOUT, + universal_newlines=True, + shell=True, + cwd=self.current_dir, + ) + return result + except KeyError: # pragma: no cover + self.logger.error(f"Unknown command number received....") + + def __repr__(self) -> str: + return f"<< {self.__class__.__name__} >>" + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser( + description="migrate/update the database schema" + ) + + argparser.add_argument( + "-i", action="store_true", help="Setup new alembic environment" + ) + argparser.add_argument( + "-r", action="store_true", help="Create new revision the database" + ) + argparser.add_argument( + "-u", action="store_true", help="Update the database to latest head" + ) + + argparser.add_argument( + "-up", action="store_true", help="Upgrade the database x revisions" + ) + argparser.add_argument( + "-down", action="store_true", help="Downgrade the database x revisions" + ) + + argparser.add_argument("-cs", action="store_true", help="Print current state") + argparser.add_argument("-hist", action="store_true", help="Print history") + + args = argparser.parse_args() + + fsm = DatabaseMigrator() + + if (args_count := len(sys.argv)) < 2 and args.r: + print("You must specify a message when submitting a new revision") + raise SystemExit(2) + elif args_count < 2 and args.up: + print("You must specify a positive number when submitting a upgrade command") + raise SystemExit(2) + elif args_count < 2 and args.down: + print("You must specify a negative number when submitting a downgrade command") + raise SystemExit(2) + + if args.i: + fsm.db_init() + + if args.r: + fsm.db_revision(sys.argv[1]) + + if args.u: + fsm.db_upgrade() + + if args.up: + fsm.db_up(int(sys.argv[1])) + + if args.down: + fsm.db_down(int(sys.argv[1])) + + if args.cs: + fsm.db_current() + + if args.hist: + fsm.db_history() diff --git a/CveXplore/core/database_models/__init__.py b/CveXplore/core/database_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CveXplore/core/database_models/models.py b/CveXplore/core/database_models/models.py new file mode 100644 index 00000000..e69de29b diff --git a/CveXplore/core/database_version/__init__.py b/CveXplore/core/database_version/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CveXplore/core/database_schema/db_schema_checker.py b/CveXplore/core/database_version/db_version_checker.py similarity index 88% rename from CveXplore/core/database_schema/db_schema_checker.py rename to CveXplore/core/database_version/db_version_checker.py index 53a27743..0cdfa231 100644 --- a/CveXplore/core/database_schema/db_schema_checker.py +++ b/CveXplore/core/database_version/db_version_checker.py @@ -2,12 +2,12 @@ import os from CveXplore.core.database_maintenance.update_base_class import UpdateBaseClass -from CveXplore.errors import DatabaseSchemaError +from CveXplore.errors import DatabaseSchemaVersionError runPath = os.path.dirname(os.path.realpath(__file__)) -class SchemaChecker(UpdateBaseClass): +class DatabaseVersionChecker(UpdateBaseClass): def __init__(self, datasource): super().__init__(__name__) with open(os.path.join(runPath, "../../.schema_version")) as f: @@ -15,7 +15,7 @@ def __init__(self, datasource): database = datasource - self.dbh = database._dbclient["schema"] + self.dbh = database.dbclient["schema"] def validate_schema(self): try: @@ -24,18 +24,18 @@ def validate_schema(self): == list(self.dbh.find({}))[0]["version"] ): if not self.schema_version["rebuild_needed"]: - raise DatabaseSchemaError( + raise DatabaseSchemaVersionError( "Database is not on the latest schema version; please update the database!" ) else: - raise DatabaseSchemaError( + raise DatabaseSchemaVersionError( "Database schema is not up to date; please re-populate the database!" ) else: return True except IndexError: # something went wrong fetching the result from the database; assume re-populate is needed - raise DatabaseSchemaError( + raise DatabaseSchemaVersionError( "Database schema is not up to date; please re-populate the database!" ) diff --git a/CveXplore/core/general/datasources.py b/CveXplore/core/general/datasources.py new file mode 100644 index 00000000..98a8f34e --- /dev/null +++ b/CveXplore/core/general/datasources.py @@ -0,0 +1 @@ +supported_datasources = {"mongodb", "api"} diff --git a/CveXplore/database/connection/base/db_connection_base.py b/CveXplore/database/connection/base/db_connection_base.py index acd7f778..c02bfe77 100644 --- a/CveXplore/database/connection/base/db_connection_base.py +++ b/CveXplore/database/connection/base/db_connection_base.py @@ -1,13 +1,19 @@ import logging +from abc import ABC, abstractmethod from CveXplore.core.logging.logger_class import AppLogger logging.setLoggerClass(AppLogger) -class DatabaseConnectionBase(object): +class DatabaseConnectionBase(ABC): def __init__(self, logger_name: str): self.logger = logging.getLogger(logger_name) def __repr__(self): return f"<<{self.__class__.__name__}>>" + + @property + @abstractmethod + def dbclient(self): + raise NotImplementedError diff --git a/CveXplore/database/connection/mongodb/mongo_db.py b/CveXplore/database/connection/mongodb/mongo_db.py index 47a1cfea..d8956b9f 100644 --- a/CveXplore/database/connection/mongodb/mongo_db.py +++ b/CveXplore/database/connection/mongodb/mongo_db.py @@ -40,7 +40,7 @@ def __init__( self._dbclient = self.client[database] try: - collections = self._dbclient.list_collection_names() + collections = self.db_client.list_collection_names() except ServerSelectionTimeoutError as err: raise DatabaseConnectionException( f"Connection to the database failed: {err}" @@ -50,18 +50,22 @@ def __init__( for each in collections: self.__setattr__( f"store_{each}", - CveSearchCollection(database=self._dbclient, name=each), + CveSearchCollection(database=self.db_client, name=each), ) atexit.register(self.disconnect) + @property + def dbclient(self): + return self._dbclient + def set_handlers_for_collections(self): - for each in self._dbclient.list_collection_names(): + for each in self.db_client.list_collection_names(): if not hasattr(self, each): setattr( self, f"store_{each}", - CveSearchCollection(database=self._dbclient, name=each), + CveSearchCollection(database=self.db_client, name=each), ) def disconnect(self): diff --git a/CveXplore/database/connection/sqlbase/__init__.py b/CveXplore/database/connection/sqlbase/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CveXplore/database/connection/sqlbase/sql_base.py b/CveXplore/database/connection/sqlbase/sql_base.py new file mode 100644 index 00000000..44047dc0 --- /dev/null +++ b/CveXplore/database/connection/sqlbase/sql_base.py @@ -0,0 +1,12 @@ +from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase + + +class SQLBase(DatabaseConnectionBase): + def __init__(self): + super().__init__(logger_name=__name__) + + self._dbclient = None + + @property + def dbclient(self): + return self._dbclient diff --git a/CveXplore/errors/database.py b/CveXplore/errors/database.py index ebb4dbff..081e8623 100644 --- a/CveXplore/errors/database.py +++ b/CveXplore/errors/database.py @@ -18,5 +18,5 @@ class UpdateSourceNotFound(DatabaseException): pass -class DatabaseSchemaError(DatabaseException): +class DatabaseSchemaVersionError(DatabaseException): pass diff --git a/CveXplore/errors/datasource.py b/CveXplore/errors/datasource.py new file mode 100644 index 00000000..39a26534 --- /dev/null +++ b/CveXplore/errors/datasource.py @@ -0,0 +1,6 @@ +class DatasourceException(Exception): + pass + + +class UnsupportedDatasourceException(DatasourceException): + pass diff --git a/CveXplore/main.py b/CveXplore/main.py index 64ca0278..b1675b2c 100644 --- a/CveXplore/main.py +++ b/CveXplore/main.py @@ -18,9 +18,12 @@ from CveXplore.common.cpe_converters import create_cpe_regex_string from CveXplore.common.db_mapping import database_mapping from CveXplore.core.database_maintenance.main_updater import MainUpdater +from CveXplore.core.database_migration.database_migrator import DatabaseMigrator +from CveXplore.core.general.datasources import supported_datasources from CveXplore.database.connection.database_connection import DatabaseConnection from CveXplore.database.connection.mongodb.mongo_db import MongoDBConnection from CveXplore.errors import DatabaseIllegalCollection +from CveXplore.errors.datasource import UnsupportedDatasourceException from CveXplore.errors.validation import CveNumberValidationError from CveXplore.objects.cvexplore_object import CveXploreObject @@ -40,16 +43,14 @@ class CveXplore(object): def __init__( self, - datasource_type: str = "mongodb", + datasource_type: str = None, datasource_connection_details: dict = None, mongodb_connection_details: dict = None, api_connection_details: dict = None, ): """ Create a new instance of CveXplore - :param datasource_type: Which datasource to query. Currently supported options include: - - mongodb - - api + :param datasource_type: Which datasource to query. :param datasource_connection_details: Provide the connection details needed to establish a connection to the datasource. The connection details should be in line with the datasource it's documentation. @@ -66,7 +67,9 @@ def __init__( self.config = Configuration() self.logger = logging.getLogger(__name__) - self._datasource_type = datasource_type + self.datasource_type = ( + datasource_type if datasource_type is not None else self.config.DATASOURCE + ) self._datasource_connection_details = datasource_connection_details self._mongodb_connection_details = mongodb_connection_details @@ -74,6 +77,20 @@ def __init__( os.environ["DOC_BUILD"] = json.dumps({"DOC_BUILD": "NO"}) + self.logger.info( + f"Using {self.datasource_type} as datasource, connection details: {self.datasource_connection_details}" + ) + + if self.datasource_type not in supported_datasources: + raise UnsupportedDatasourceException( + f"Unsupported datasource selected: '{self.datasource_type}'; currently supported: {supported_datasources}" + ) + + if self.datasource_type == "api" and self.datasource_connection_details is None: + raise ValueError( + "Missing datasource_connection_details for selected datasource ('api')" + ) + if self.mongodb_connection_details is not None: self.logger.warning( "The use of mongodb_connection_details is deprecated and will be removed in the 0.4 release, please " @@ -108,6 +125,8 @@ def __init__( ).database_connection self.database = MainUpdater(datasource=self.datasource) + self.database_migrator = DatabaseMigrator() + self._database_mapping = database_mapping from CveXplore.database.helpers.specific_db import ( @@ -124,10 +143,6 @@ def __init__( self.logger.info(f"Initialized CveXplore version: {self.version}") - @property - def datasource_type(self): - return self._datasource_type - @property def datasource_connection_details(self): return self._datasource_connection_details diff --git a/CveXplore/objects/cvexplore_object.py b/CveXplore/objects/cvexplore_object.py index 7dbe7d79..440bbe80 100644 --- a/CveXplore/objects/cvexplore_object.py +++ b/CveXplore/objects/cvexplore_object.py @@ -2,6 +2,7 @@ CveXploreObject =============== """ +from CveXplore.common.config import Configuration class CveXploreObject(object): @@ -10,7 +11,7 @@ class CveXploreObject(object): """ def __init__(self): - pass + self.config = Configuration() def __repr__(self) -> str: return f"<< {self.__class__.__name__} >>" diff --git a/requirements/modules/sqlalchemy.txt b/requirements/modules/sqlalchemy.txt index e69de29b..7c4226e3 100644 --- a/requirements/modules/sqlalchemy.txt +++ b/requirements/modules/sqlalchemy.txt @@ -0,0 +1,2 @@ +sqlalchemy>=2.0.23 +alembic>=1.13.0