Skip to content

Commit

Permalink
Database migration class added (part of #214) (#244)
Browse files Browse the repository at this point in the history
* [CveXplore-240] relates to #243

* [CveXplore-240] started with db_connection abstract class and defining general properties

* [CveXplore-240] renamed database_schema to database_version

* [CveXplore-240] minor rename actions

* [CveXplore-240] Created class for database migrations

* [CveXplore-240] Created database models folder
  • Loading branch information
P-T-I authored Dec 20, 2023
1 parent 129cbba commit a997398
Show file tree
Hide file tree
Showing 21 changed files with 249 additions and 30 deletions.
2 changes: 1 addition & 1 deletion CveXplore/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.3.20.dev7
0.3.20.dev11
3 changes: 3 additions & 0 deletions CveXplore/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ class Configuration(object):

CPE_FILTER_DEPRECATED = getenv_bool("CPE_FILTER_DEPRECATED", "True")

# Which datasource to query.Currently supported options include:
# - mongodb
# - api
DATASOURCE = os.getenv("DATASOURCE", "mongodb")

DATASOURCE_PROTOCOL = os.getenv("DATASOURCE_PROTOCOL", "mongodb")
Expand Down
2 changes: 1 addition & 1 deletion CveXplore/core/database_indexer/db_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, datasource):
super().__init__(__name__)

database = datasource
self.database = database._dbclient
self.database = database.dbclient

self.indexes = {
"cpe": [
Expand Down
8 changes: 4 additions & 4 deletions CveXplore/core/database_maintenance/download_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class DownloadHandler(ABC):
"""

def __init__(self, feed_type: str, logger_name: str, prefix: str = None):
self.config = Configuration()

self._end = None

self.feed_type = feed_type
Expand All @@ -65,16 +67,14 @@ def __init__(self, feed_type: str, logger_name: str, prefix: str = None):
self.do_process = True

database = DatabaseConnection(
database_type=os.getenv("DATASOURCE_TYPE"),
database_type=self.config.DATASOURCE,
database_init_parameters=json.loads(os.getenv("DATASOURCE_CON_DETAILS")),
).database_connection

self.database = database._dbclient
self.database = database.dbclient

self.database_indexer = DatabaseIndexer(datasource=database)

self.config = Configuration()

self.logger = logging.getLogger(logger_name)

self.logger.removeHandler(self.logger.handlers[0])
Expand Down
4 changes: 2 additions & 2 deletions CveXplore/core/database_maintenance/main_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
EPSSDownloads,
)
from CveXplore.core.database_maintenance.update_base_class import UpdateBaseClass
from CveXplore.core.database_schema.db_schema_checker import SchemaChecker
from CveXplore.core.database_version.db_version_checker import DatabaseVersionChecker
from CveXplore.core.logging.logger_class import AppLogger
from CveXplore.errors import UpdateSourceNotFound

Expand Down Expand Up @@ -46,7 +46,7 @@ def __init__(self, datasource):
]

self.database_indexer = DatabaseIndexer(datasource=datasource)
self.schema_checker = SchemaChecker(datasource=datasource)
self.schema_checker = DatabaseVersionChecker(datasource=datasource)

def validate_schema(self):
return self.schema_checker.validate_schema()
Expand Down
File renamed without changes.
169 changes: 169 additions & 0 deletions CveXplore/core/database_migration/database_migrator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import argparse
import logging
import os
import sys
from collections import namedtuple
from subprocess import run, PIPE, STDOUT, CompletedProcess

from CveXplore.core.logging.logger_class import AppLogger

logging.setLoggerClass(AppLogger)


class DatabaseMigrator(object):
def __init__(self, cwd: str = None):
self.logger = logging.getLogger(__name__)

self.current_dir = (
cwd
if cwd is not None
else os.path.dirname(
os.path.dirname(
os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
)
)
)

self._commands = namedtuple(
"commands", "INIT REVISION UPGRADE CURRENT HISTORY REV_UP REV_DOWN"
)(1, 2, 3, 4, 5, 6, 7)

@property
def commands(self) -> namedtuple:
return self._commands

def db_init(self) -> None:
res = self.__cli_runner(self.commands.INIT)
self.__parse_command_output(res)

def db_revision(self, message: str) -> None:
res = self.__cli_runner(self.commands.REVISION, message=message)
self.__parse_command_output(res)

def db_upgrade(self) -> None:
res = self.__cli_runner(self.commands.UPGRADE)
self.__parse_command_output(res)

def db_current(self) -> None:
res = self.__cli_runner(self.commands.CURRENT)
self.__parse_command_output(res)

def db_history(self) -> None:
res = self.__cli_runner(self.commands.HISTORY)
self.__parse_command_output(res)

def db_up(self, count: int) -> None:
res = self.__cli_runner(self.commands.REV_UP, message=count)
self.__parse_command_output(res)

def db_down(self, count: int) -> None:
res = self.__cli_runner(self.commands.REV_DOWN, message=count)
self.__parse_command_output(res)

def __parse_command_output(self, cmd_output: CompletedProcess) -> None:
if cmd_output.returncode != 0:
self.logger.error(cmd_output.stdout.split("\n")[0])
else:
output_list = cmd_output.stdout.split("\n")

for m in output_list:
if m != "":
self.logger.info(m)

def __cli_runner(self, command: int, message: str | int = None) -> CompletedProcess:
if command == 2 and message is None:
raise ValueError("Missing message for revision command")
elif command == 6 and message is None:
raise ValueError(
"You must specify a positive number when submitting a upgrade command"
)
elif command == 7 and message is None:
raise ValueError(
"You must specify a negative number when submitting a downgrade command"
)

command_mapping = {
1: f"alembic init alembic",
2: f"alembic revision -m {message}",
3: f"alembic upgrade head",
4: f"alembic current",
5: f"alembic history --verbose",
6: f"alembic upgrade {message}",
7: f"alembic downgrade {message}",
}
try:
result = run(
command_mapping[command], # nosec
stdout=PIPE,
stderr=STDOUT,
universal_newlines=True,
shell=True,
cwd=self.current_dir,
)
return result
except KeyError: # pragma: no cover
self.logger.error(f"Unknown command number received....")

def __repr__(self) -> str:
return f"<< {self.__class__.__name__} >>"


if __name__ == "__main__":
argparser = argparse.ArgumentParser(
description="migrate/update the database schema"
)

argparser.add_argument(
"-i", action="store_true", help="Setup new alembic environment"
)
argparser.add_argument(
"-r", action="store_true", help="Create new revision the database"
)
argparser.add_argument(
"-u", action="store_true", help="Update the database to latest head"
)

argparser.add_argument(
"-up", action="store_true", help="Upgrade the database x revisions"
)
argparser.add_argument(
"-down", action="store_true", help="Downgrade the database x revisions"
)

argparser.add_argument("-cs", action="store_true", help="Print current state")
argparser.add_argument("-hist", action="store_true", help="Print history")

args = argparser.parse_args()

fsm = DatabaseMigrator()

if (args_count := len(sys.argv)) < 2 and args.r:
print("You must specify a message when submitting a new revision")
raise SystemExit(2)
elif args_count < 2 and args.up:
print("You must specify a positive number when submitting a upgrade command")
raise SystemExit(2)
elif args_count < 2 and args.down:
print("You must specify a negative number when submitting a downgrade command")
raise SystemExit(2)

if args.i:
fsm.db_init()

if args.r:
fsm.db_revision(sys.argv[1])

if args.u:
fsm.db_upgrade()

if args.up:
fsm.db_up(int(sys.argv[1]))

if args.down:
fsm.db_down(int(sys.argv[1]))

if args.cs:
fsm.db_current()

if args.hist:
fsm.db_history()
Empty file.
Empty file.
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,20 @@
import os

from CveXplore.core.database_maintenance.update_base_class import UpdateBaseClass
from CveXplore.errors import DatabaseSchemaError
from CveXplore.errors import DatabaseSchemaVersionError

runPath = os.path.dirname(os.path.realpath(__file__))


class SchemaChecker(UpdateBaseClass):
class DatabaseVersionChecker(UpdateBaseClass):
def __init__(self, datasource):
super().__init__(__name__)
with open(os.path.join(runPath, "../../.schema_version")) as f:
self.schema_version = json.loads(f.read())

database = datasource

self.dbh = database._dbclient["schema"]
self.dbh = database.dbclient["schema"]

def validate_schema(self):
try:
Expand All @@ -24,18 +24,18 @@ def validate_schema(self):
== list(self.dbh.find({}))[0]["version"]
):
if not self.schema_version["rebuild_needed"]:
raise DatabaseSchemaError(
raise DatabaseSchemaVersionError(
"Database is not on the latest schema version; please update the database!"
)
else:
raise DatabaseSchemaError(
raise DatabaseSchemaVersionError(
"Database schema is not up to date; please re-populate the database!"
)
else:
return True
except IndexError:
# something went wrong fetching the result from the database; assume re-populate is needed
raise DatabaseSchemaError(
raise DatabaseSchemaVersionError(
"Database schema is not up to date; please re-populate the database!"
)

Expand Down
1 change: 1 addition & 0 deletions CveXplore/core/general/datasources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
supported_datasources = {"mongodb", "api"}
8 changes: 7 additions & 1 deletion CveXplore/database/connection/base/db_connection_base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
import logging
from abc import ABC, abstractmethod

from CveXplore.core.logging.logger_class import AppLogger

logging.setLoggerClass(AppLogger)


class DatabaseConnectionBase(object):
class DatabaseConnectionBase(ABC):
def __init__(self, logger_name: str):
self.logger = logging.getLogger(logger_name)

def __repr__(self):
return f"<<{self.__class__.__name__}>>"

@property
@abstractmethod
def dbclient(self):
raise NotImplementedError
12 changes: 8 additions & 4 deletions CveXplore/database/connection/mongodb/mongo_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self._dbclient = self.client[database]

try:
collections = self._dbclient.list_collection_names()
collections = self.db_client.list_collection_names()
except ServerSelectionTimeoutError as err:
raise DatabaseConnectionException(
f"Connection to the database failed: {err}"
Expand All @@ -50,18 +50,22 @@ def __init__(
for each in collections:
self.__setattr__(
f"store_{each}",
CveSearchCollection(database=self._dbclient, name=each),
CveSearchCollection(database=self.db_client, name=each),
)

atexit.register(self.disconnect)

@property
def dbclient(self):
return self._dbclient

def set_handlers_for_collections(self):
for each in self._dbclient.list_collection_names():
for each in self.db_client.list_collection_names():
if not hasattr(self, each):
setattr(
self,
f"store_{each}",
CveSearchCollection(database=self._dbclient, name=each),
CveSearchCollection(database=self.db_client, name=each),
)

def disconnect(self):
Expand Down
Empty file.
12 changes: 12 additions & 0 deletions CveXplore/database/connection/sqlbase/sql_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase


class SQLBase(DatabaseConnectionBase):
def __init__(self):
super().__init__(logger_name=__name__)

self._dbclient = None

@property
def dbclient(self):
return self._dbclient
2 changes: 1 addition & 1 deletion CveXplore/errors/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ class UpdateSourceNotFound(DatabaseException):
pass


class DatabaseSchemaError(DatabaseException):
class DatabaseSchemaVersionError(DatabaseException):
pass
6 changes: 6 additions & 0 deletions CveXplore/errors/datasource.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
class DatasourceException(Exception):
pass


class UnsupportedDatasourceException(DatasourceException):
pass
Loading

0 comments on commit a997398

Please sign in to comment.