From d62672e2b0d39d93e9a8a6908a946d6b058cbcd8 Mon Sep 17 00:00:00 2001 From: Paul Tikken Date: Tue, 23 Apr 2024 13:03:23 +0000 Subject: [PATCH] refactoring errors on imports --- .../connection/base/db_connection_base.py | 2 + .../connection/database_connection.py | 1 + .../database/connection/sqlbase/connection.py | 5 +- .../database/connection/sqlbase/sql_base.py | 52 ++++++++++--------- 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/CveXplore/database/connection/base/db_connection_base.py b/CveXplore/database/connection/base/db_connection_base.py index e58068e5..bd40eabe 100644 --- a/CveXplore/database/connection/base/db_connection_base.py +++ b/CveXplore/database/connection/base/db_connection_base.py @@ -1,6 +1,7 @@ import logging from abc import ABC, abstractmethod +from CveXplore.common.config import Configuration from CveXplore.core.logging.logger_class import AppLogger logging.setLoggerClass(AppLogger) @@ -9,6 +10,7 @@ class DatabaseConnectionBase(ABC): def __init__(self, logger_name: str): self.logger = logging.getLogger(logger_name) + self.config = Configuration def __repr__(self): return f"<<{self.__class__.__name__}>>" diff --git a/CveXplore/database/connection/database_connection.py b/CveXplore/database/connection/database_connection.py index 90de5598..84960b6d 100644 --- a/CveXplore/database/connection/database_connection.py +++ b/CveXplore/database/connection/database_connection.py @@ -7,6 +7,7 @@ class DatabaseConnection(object): def __init__(self, database_type: str, database_init_parameters: dict): + self.database_type = database_type self.database_init_parameters = database_init_parameters diff --git a/CveXplore/database/connection/sqlbase/connection.py b/CveXplore/database/connection/sqlbase/connection.py index 8b891ffd..4bcc88ce 100644 --- a/CveXplore/database/connection/sqlbase/connection.py +++ b/CveXplore/database/connection/sqlbase/connection.py @@ -5,6 +5,7 @@ config = Configuration -engine = create_engine(config.SQLALCHEMY_DATABASE_URI, echo=True) +if config.DATASOURCE_TYPE != "mongodb": + engine = create_engine(config.SQLALCHEMY_DATABASE_URI, echo=True) -Session = sessionmaker(bind=engine) + Session = sessionmaker(bind=engine) diff --git a/CveXplore/database/connection/sqlbase/sql_base.py b/CveXplore/database/connection/sqlbase/sql_base.py index 97b0c099..02b2e580 100644 --- a/CveXplore/database/connection/sqlbase/sql_base.py +++ b/CveXplore/database/connection/sqlbase/sql_base.py @@ -1,5 +1,4 @@ from CveXplore.database.connection.base.db_connection_base import DatabaseConnectionBase -from CveXplore.database.connection.sqlbase.sql_client import SQLClient from CveXplore.database_models.models import CveXploreBase from CveXplore.errors import DatabaseConnectionException @@ -7,33 +6,38 @@ class SQLBaseConnection(DatabaseConnectionBase): def __init__(self, **kwargs): super().__init__(logger_name=__name__) + if self.config.DATASOURCE_TYPE != "mongodb": + from CveXplore.database.connection.sqlbase.sql_client import SQLClient - self._dbclient = { - "info": SQLClient("info"), - "cpe": SQLClient("cpe"), - "cves": SQLClient("cves"), - "schema": SQLClient("schema"), - "cwe": SQLClient("cwe"), - "capec": SQLClient("capec"), - "via4": SQLClient("via4"), - } - - try: - collections = list(CveXploreBase.metadata.tables.keys()) - except ConnectionError as err: - raise DatabaseConnectionException( - f"Connection to the database failed: {err}" - ) - - if len(collections) != 0: - for each in collections: - self.__setattr__(f"store_{each}", SQLClient(each)) + self._dbclient = { + "info": SQLClient("info"), + "cpe": SQLClient("cpe"), + "cves": SQLClient("cves"), + "schema": SQLClient("schema"), + "cwe": SQLClient("cwe"), + "capec": SQLClient("capec"), + "via4": SQLClient("via4"), + } + + try: + collections = list(CveXploreBase.metadata.tables.keys()) + except ConnectionError as err: + raise DatabaseConnectionException( + f"Connection to the database failed: {err}" + ) + + if len(collections) != 0: + for each in collections: + self.__setattr__(f"store_{each}", SQLClient(each)) @property def dbclient(self): return self._dbclient def set_handlers_for_collections(self): - for each in list(CveXploreBase.metadata.tables.keys()): - if not hasattr(self, each): - setattr(self, f"store_{each}", SQLClient(each)) + if self.config.DATASOURCE_TYPE != "mongodb": + from CveXplore.database.connection.sqlbase.sql_client import SQLClient + + for each in list(CveXploreBase.metadata.tables.keys()): + if not hasattr(self, each): + setattr(self, f"store_{each}", SQLClient(each))