diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 6793413..a48e036 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -26,8 +26,8 @@ jobs:
       - name: Lint
         run: |
-          pylint app tests main.py setup.py --output pylint-report.txt
           pylint app tests main.py setup.py
+          pylint app tests main.py setup.py --output pylint-report.txt
 
       - name: Testing
         run: |
diff --git a/app/connect_databricks.py b/app/connect_databricks.py
index 74a1f98..3347774 100644
--- a/app/connect_databricks.py
+++ b/app/connect_databricks.py
@@ -12,6 +12,9 @@ def init_databricks():
     return spark, dbutils
 
 
+def get_param_value(dbutils, param_key):
+    return dbutils.widgets.get(param_key)
+
 
 def create_mount(dbutils, container_name, mount_path):
     storage_name = os.environ["storage_account_name"]
diff --git a/app/Extraction.py b/app/extraction.py
similarity index 100%
rename from app/Extraction.py
rename to app/extraction.py
diff --git a/app/SparkWrapper.py b/app/spark_wrapper.py
similarity index 84%
rename from app/SparkWrapper.py
rename to app/spark_wrapper.py
index 4816291..d28b9e6 100644
--- a/app/SparkWrapper.py
+++ b/app/spark_wrapper.py
@@ -1,10 +1,8 @@
-from pyspark.sql import SparkSession, DataFrame
+from pyspark.sql import DataFrame
 from pyspark.sql import Window, WindowSpec
 
 
-def create_frame(sc: SparkSession | None, path: str):
-    if sc is None:
-        raise TypeError(f"{sc} is None. Pass Spark Session")
+def create_frame(sc, path: str):
     df = sc.read.csv(path, inferSchema=True, header=True)
     return df
diff --git a/main.py b/main.py
index 0493b56..b65cf8f 100644
--- a/main.py
+++ b/main.py
@@ -7,48 +7,54 @@
 from pyspark.sql import Window
 import pyspark.sql.functions as F
 
-import app.SparkWrapper as sw
+import app.spark_wrapper as sw
 
 os.system("pip install python-dotenv")
 
 import dotenv  # pylint: disable=wrong-import-position, disable=wrong-import-order
 
 # COMMAND ----------
 
-try:
-    import app.connect_databricks as cd  # pylint: disable=ungrouped-imports
-    import json
+# try:
+#     import app.connect_databricks as cd  # pylint: disable=ungrouped-imports
+#     import json
 
-    # Comment the following line if running directly in cloud notebook
-    spark, dbutils = cd.init_databricks()
+#     # Comment the following line if running directly in cloud notebook
+#     spark, dbutils = cd.init_databricks()
 
-    with open("/dbfs/mnt/config/keys.json", encoding="utf-8") as file:
-        keys = json.load(file)
+#     with open("/dbfs/mnt/config/keys.json", encoding="utf-8") as file:
+#         keys = json.load(file)
 
-    flag = keys["flag"]
-except:  # pylint: disable=bare-except
-    flag = "False"
+#     flag = keys["flag"]
+# except:  # pylint: disable=bare-except
+#     flag = "False"
 
-flag = bool(flag == "True")
+# flag = bool(flag == "True")
 
-# if 'spark' in locals():
-#     flag = True
-# else:
-#     spark = None
-#     dbutils = None
-#     flag = False
+if "dbutils" in locals():
+    flag = True
+else:
+    spark = None
+    dbutils = None
+    flag = False
 
 # COMMAND ----------
 
 if flag:
-    os.environ["KAGGLE_USERNAME"] = dbutils.widgets.get("kaggle_username")
+    import app.connect_databricks as cd
 
-    os.environ["KAGGLE_KEY"] = dbutils.widgets.get("kaggle_token")
+    os.environ["KAGGLE_USERNAME"] = cd.get_param_value(dbutils, "kaggle_username")
 
-    os.environ["storage_account_name"] = dbutils.widgets.get("storage_account_name")
+    os.environ["KAGGLE_KEY"] = cd.get_param_value(dbutils, "kaggle_token")
 
-    os.environ["datalake_access_key"] = dbutils.widgets.get("datalake_access_key")
+    os.environ["storage_account_name"] = cd.get_param_value(
+        dbutils, "storage_account_name"
+    )
+
+    os.environ["datalake_access_key"] = cd.get_param_value(
+        dbutils, "datalake_access_key"
+    )
 
 # COMMAND ----------
 
@@ -84,7 +90,7 @@
 
 # COMMAND ----------
 
-from app.Extraction import extract_from_kaggle  # pylint: disable=wrong-import-position
+from app.extraction import extract_from_kaggle  # pylint: disable=wrong-import-position
 
 # COMMAND ----------
diff --git a/tests/connect_glue_test.py b/tests/connect_glue_test.py
deleted file mode 100644
index 16a837a..0000000
--- a/tests/connect_glue_test.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import unittest
-from unittest.mock import patch, MagicMock
-from app.connect_glue import init_glue
-
-
-class TestInitGlue(unittest.TestCase):
-    @patch("app.connect_glue.SparkContext")
-    @patch("app.connect_glue.GlueContext")
-    @patch("app.connect_glue.Job")
-    def test_init_glue(self, mock_job, mock_glue_context, mock_spark_context):
-        # Mock the SparkContext, GlueContext, and Job
-        mock_spark_context_instance = MagicMock()
-        mock_glue_context_instance = MagicMock()
-        mock_job_instance = MagicMock()
-
-        # Set up the behavior of the mock instances
-        mock_spark_context.return_value = mock_spark_context_instance
-        mock_glue_context.return_value = mock_glue_context_instance
-        mock_job.return_value = mock_job_instance
-
-        # Call the function to test
-        glue_context, spark, job = init_glue()
-
-        # Assertions
-        mock_spark_context.assert_called_once()
-        mock_glue_context.assert_called_once_with(mock_spark_context_instance)
-        mock_job.assert_called_once_with(mock_glue_context_instance)
-
-        # Check if the returned values are correct
-        self.assertEqual(glue_context, mock_glue_context_instance)
-        self.assertEqual(spark, mock_glue_context_instance.spark_session)
-        self.assertEqual(job, mock_job_instance)
-
-    @patch("app.connect_glue.SparkContext")
-    @patch("app.connect_glue.GlueContext")
-    @patch("app.connect_glue.Job")
-    def test_init_glue_failure(self, mock_job, mock_glue_context, mock_spark_context):
-        # Simulate a ValueError during SparkContext initialization
-        error_statement = "Simulated SparkContext initialization failure"
-        mock_spark_context.side_effect = ValueError(error_statement)
-
-        # Call the function to test
-        with self.assertRaises(ValueError) as context:
-            init_glue()
-
-        # Assertions
-        mock_spark_context.assert_called_once()
-        mock_glue_context.assert_not_called()  # GlueContext should not be called if SparkContext initialization fails
-        mock_job.assert_not_called()  # Job should not be called if SparkContext initialization fails
-
-        # Check if the error displayed correctly
-        self.assertEqual(str(context.exception), error_statement)
diff --git a/tests/test_Extraction.py b/tests/test_extraction.py
similarity index 94%
rename from tests/test_Extraction.py
rename to tests/test_extraction.py
index 601d30d..eefb4e2 100644
--- a/tests/test_Extraction.py
+++ b/tests/test_extraction.py
@@ -1,10 +1,10 @@
 import unittest
 from unittest.mock import patch
-from app.Extraction import extract_from_kaggle
+from app.extraction import extract_from_kaggle
 
 
 class TestExtraction(unittest.TestCase):
-    @patch("app.Extraction.kaggle")
+    @patch("app.extraction.kaggle")
     def test_extract_from_kaggle_success(self, mock_kaggle):
         mock_kaggle_instance = mock_kaggle
         mock_api_instance = mock_kaggle_instance.KaggleApi.return_value
@@ -26,7 +26,7 @@ def test_extract_from_kaggle_success(self, mock_kaggle):
 
         self.assertEqual(result, ("/mnt/rawdata/", "/mnt/transformed/"))
 
-    @patch("app.Extraction.kaggle")
+    @patch("app.extraction.kaggle")
     def test_extract_from_kaggle_success_false(self, mock_kaggle):
         mock_kaggle_instance = mock_kaggle
         mock_api_instance = mock_kaggle_instance.KaggleApi.return_value
@@ -52,7 +52,7 @@ def test_extract_from_kaggle_success_false(self, mock_kaggle):
             ),
         )
 
-    @patch("app.Extraction.kaggle")
+    @patch("app.extraction.kaggle")
     def test_extract_from_kaggle_failure(self, mock_kaggle):
         mock_kaggle_instance = mock_kaggle
         mock_api_instance = mock_kaggle_instance.KaggleApi.return_value
diff --git a/tests/test_SparkWrapper.py b/tests/test_spark_wrapper.py
similarity index 98%
rename from tests/test_SparkWrapper.py
rename to tests/test_spark_wrapper.py
index a907bf3..d0b451e 100644
--- a/tests/test_SparkWrapper.py
+++ b/tests/test_spark_wrapper.py
@@ -1,7 +1,7 @@
 from unittest import TestCase
 from pyspark.sql import SparkSession
 from pyspark.sql import functions as F
-from app.SparkWrapper import value_counts, rename_columns, create_frame, make_window
+from app.spark_wrapper import value_counts, rename_columns, create_frame, make_window
 
 
 class TestSparkWrapper(TestCase):
diff --git a/tests/test_SparkWrapperFailure.py b/tests/test_spark_wrapper_failure.py
similarity index 97%
rename from tests/test_SparkWrapperFailure.py
rename to tests/test_spark_wrapper_failure.py
index 7425b42..8840012 100644
--- a/tests/test_SparkWrapperFailure.py
+++ b/tests/test_spark_wrapper_failure.py
@@ -3,7 +3,7 @@
 from pyspark.sql import SparkSession
 from pyspark.sql import functions as F
 from pyspark.sql import utils as U
-from app.SparkWrapper import value_counts, rename_columns, create_frame, make_window
+from app.spark_wrapper import value_counts, rename_columns, create_frame, make_window
 
 
 class TestSparkWrapper(TestCase):
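
Note on the new helper (not part of the patch above): get_param_value in app/connect_databricks.py is a thin seam over dbutils.widgets.get, which lets the widget lookups in main.py be exercised without a Databricks runtime. A minimal sketch of a unit test for it, in the same unittest/MagicMock style as the repo's existing tests; the file placement, class, and method names here are hypothetical:

import unittest
from unittest.mock import MagicMock

from app.connect_databricks import get_param_value


class TestGetParamValue(unittest.TestCase):
    def test_get_param_value_reads_widget(self):
        # dbutils only exists inside a Databricks runtime, so stand in a mock
        mock_dbutils = MagicMock()
        mock_dbutils.widgets.get.return_value = "some-user"

        result = get_param_value(mock_dbutils, "kaggle_username")

        # The helper should delegate straight to dbutils.widgets.get
        mock_dbutils.widgets.get.assert_called_once_with("kaggle_username")
        self.assertEqual(result, "some-user")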