
Merge pull request #7 from wednesday-solutions/feat/refactoring
Fix: changed dbutils to be used in functions & refactored code
vighnesh-wednesday authored Dec 14, 2023
2 parents 16fbbae + 58922f3 commit 3d4bf2a
Showing 9 changed files with 41 additions and 86 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
@@ -26,8 +26,8 @@ jobs:

       - name: Lint
         run: |
-          pylint app tests main.py setup.py --output pylint-report.txt
           pylint app tests main.py setup.py
+          pylint app tests main.py setup.py --output pylint-report.txt
       - name: Testing
         run: |
3 changes: 3 additions & 0 deletions app/connect_databricks.py
@@ -12,6 +12,9 @@ def init_databricks():

     return spark, dbutils
 
+def get_param_value(dbutils, param_key):
+    return dbutils.widgets.get(param_key)
+
 
 def create_mount(dbutils, container_name, mount_path):
     storage_name = os.environ["storage_account_name"]
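The new get_param_value helper is a thin wrapper around dbutils.widgets.get, which is what lets main.py (further down) read notebook parameters through a plain function instead of touching dbutils directly, and makes the lookup easy to fake in tests. A minimal sketch of that idea; the SimpleNamespace stand-in for dbutils below is purely illustrative and not part of the repository:

from types import SimpleNamespace

# The helper exactly as added in this commit (app/connect_databricks.py).
def get_param_value(dbutils, param_key):
    return dbutils.widgets.get(param_key)

# Illustrative stand-in for Databricks' dbutils, exposing only widgets.get.
fake_dbutils = SimpleNamespace(
    widgets=SimpleNamespace(get=lambda key: {"kaggle_username": "demo"}[key])
)

assert get_param_value(fake_dbutils, "kaggle_username") == "demo"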
app/Extraction.py → app/extraction.py: File renamed without changes.
6 changes: 2 additions & 4 deletions app/SparkWrapper.py → app/spark_wrapper.py
@@ -1,10 +1,8 @@
-from pyspark.sql import SparkSession, DataFrame
+from pyspark.sql import DataFrame
 from pyspark.sql import Window, WindowSpec
 
 
-def create_frame(sc: SparkSession | None, path: str):
-    if sc is None:
-        raise TypeError(f"{sc} is None. Pass Spark Session")
+def create_frame(sc, path: str):
     df = sc.read.csv(path, inferSchema=True, header=True)
     return df
 
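With the SparkSession type hint and the None guard removed from create_frame, the caller is now responsible for passing a live SparkSession (passing None would surface as an AttributeError rather than the old TypeError). A minimal local usage sketch, assuming pyspark is installed; the CSV path is a placeholder, not a file from the repository:

from pyspark.sql import SparkSession

import app.spark_wrapper as sw

spark = (
    SparkSession.builder.master("local[*]").appName("spark-wrapper-demo").getOrCreate()
)

# create_frame still reads a CSV with header and schema inference, as before.
df = sw.create_frame(spark, "data/sample.csv")  # placeholder path
df.show(5)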
52 changes: 29 additions & 23 deletions main.py
@@ -7,48 +7,54 @@
 from pyspark.sql import Window
 import pyspark.sql.functions as F
 
-import app.SparkWrapper as sw
+import app.spark_wrapper as sw
 
 os.system("pip install python-dotenv")
 import dotenv # pylint: disable=wrong-import-position, disable=wrong-import-order
 
 # COMMAND ----------
 
-try:
-    import app.connect_databricks as cd # pylint: disable=ungrouped-imports
-    import json
+# try:
+# import app.connect_databricks as cd # pylint: disable=ungrouped-imports
+# import json
 
-    # Comment the following line if running directly in cloud notebook
-    spark, dbutils = cd.init_databricks()
+# # Comment the following line if running directly in cloud notebook
+# spark, dbutils = cd.init_databricks()
 
-    with open("/dbfs/mnt/config/keys.json", encoding="utf-8") as file:
-        keys = json.load(file)
+# with open("/dbfs/mnt/config/keys.json", encoding="utf-8") as file:
+# keys = json.load(file)
 
-    flag = keys["flag"]
-except: # pylint: disable=bare-except
-    flag = "False"
+# flag = keys["flag"]
+# except: # pylint: disable=bare-except
+# flag = "False"
 
 
-flag = bool(flag == "True")
+# flag = bool(flag == "True")
 
-# if 'spark' in locals():
-# flag = True
-# else:
-# spark = None
-# dbutils = None
-# flag = False
+if "dbutils" in locals():
+    flag = True
+else:
+    spark = None
+    dbutils = None
+    flag = False
 
 
 # COMMAND ----------
 
 if flag:
-    os.environ["KAGGLE_USERNAME"] = dbutils.widgets.get("kaggle_username")
+    import app.connect_databricks as cd
 
-    os.environ["KAGGLE_KEY"] = dbutils.widgets.get("kaggle_token")
+    os.environ["KAGGLE_USERNAME"] = cd.get_param_value(dbutils, "kaggle_username")
 
-    os.environ["storage_account_name"] = dbutils.widgets.get("storage_account_name")
+    os.environ["KAGGLE_KEY"] = cd.get_param_value(dbutils, "kaggle_token")
 
-    os.environ["datalake_access_key"] = dbutils.widgets.get("datalake_access_key")
+    os.environ["storage_account_name"] = cd.get_param_value(
+        dbutils, "storage_account_name"
+    )
+
+    os.environ["datalake_access_key"] = cd.get_param_value(
+        dbutils, "datalake_access_key"
+    )
 
 
 # COMMAND ----------
@@ -84,7 +90,7 @@


 # COMMAND ----------
-from app.Extraction import extract_from_kaggle # pylint: disable=wrong-import-position
+from app.extraction import extract_from_kaggle # pylint: disable=wrong-import-position
 
 # COMMAND ----------

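The reworked bootstrap in main.py comes down to two moves: treat the presence of dbutils as "running on Databricks", and route every widget lookup through the new get_param_value helper. A condensed sketch of that flow which can be followed outside Databricks; the SimpleNamespace stub and its parameter values are hypothetical, and the helper is repeated inline only so the snippet is self-contained:

import os
from types import SimpleNamespace

# Same two-line helper this commit adds to app/connect_databricks.py.
def get_param_value(dbutils, param_key):
    return dbutils.widgets.get(param_key)

# Hypothetical stub; on Databricks the runtime injects dbutils and this block would not exist.
dbutils = SimpleNamespace(
    widgets=SimpleNamespace(
        get=lambda key: {"kaggle_username": "demo-user", "kaggle_token": "demo-key"}[key]
    )
)

# main.py's detection: dbutils defined means Databricks, otherwise fall back to local mode.
if "dbutils" in locals():
    flag = True
else:
    spark = None
    dbutils = None
    flag = False

if flag:
    os.environ["KAGGLE_USERNAME"] = get_param_value(dbutils, "kaggle_username")
    os.environ["KAGGLE_KEY"] = get_param_value(dbutils, "kaggle_token")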
52 changes: 0 additions & 52 deletions tests/connect_glue_test.py

This file was deleted.

8 changes: 4 additions & 4 deletions tests/test_Extraction.py → tests/test_extraction.py
@@ -1,10 +1,10 @@
 import unittest
 from unittest.mock import patch
-from app.Extraction import extract_from_kaggle
+from app.extraction import extract_from_kaggle
 
 
 class TestExtraction(unittest.TestCase):
-    @patch("app.Extraction.kaggle")
+    @patch("app.extraction.kaggle")
     def test_extract_from_kaggle_success(self, mock_kaggle):
         mock_kaggle_instance = mock_kaggle
         mock_api_instance = mock_kaggle_instance.KaggleApi.return_value
@@ -26,7 +26,7 @@ def test_extract_from_kaggle_success(self, mock_kaggle):

         self.assertEqual(result, ("/mnt/rawdata/", "/mnt/transformed/"))
 
-    @patch("app.Extraction.kaggle")
+    @patch("app.extraction.kaggle")
     def test_extract_from_kaggle_success_false(self, mock_kaggle):
         mock_kaggle_instance = mock_kaggle
         mock_api_instance = mock_kaggle_instance.KaggleApi.return_value
@@ -52,7 +52,7 @@ def test_extract_from_kaggle_success_false(self, mock_kaggle):
             ),
         )
 
-    @patch("app.Extraction.kaggle")
+    @patch("app.extraction.kaggle")
     def test_extract_from_kaggle_failure(self, mock_kaggle):
         mock_kaggle_instance = mock_kaggle
         mock_api_instance = mock_kaggle_instance.KaggleApi.return_value
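The only substantive change in this module is the patch target: unittest.mock.patch replaces a name where it is looked up, so renaming app/Extraction.py to app/extraction.py forces every decorator to point at the new module path. A stripped-down illustration of that rule, mirroring the existing tests; it assumes the kaggle package is importable in the test environment, which the repository's own tests already require:

import unittest
from unittest.mock import patch

class TestPatchTarget(unittest.TestCase):
    # Patch kaggle where app.extraction looks it up, not where the kaggle package lives.
    @patch("app.extraction.kaggle")
    def test_patch_points_at_renamed_module(self, mock_kaggle):
        from app import extraction

        self.assertIs(extraction.kaggle, mock_kaggle)

if __name__ == "__main__":
    unittest.main()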
2 changes: 1 addition & 1 deletion tests/test_SparkWrapper.py → tests/test_spark_wrapper.py
@@ -1,7 +1,7 @@
 from unittest import TestCase
 from pyspark.sql import SparkSession
 from pyspark.sql import functions as F
-from app.SparkWrapper import value_counts, rename_columns, create_frame, make_window
+from app.spark_wrapper import value_counts, rename_columns, create_frame, make_window
 
 
 class TestSparkWrapper(TestCase):
2 changes: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 from pyspark.sql import SparkSession
 from pyspark.sql import functions as F
 from pyspark.sql import utils as U
-from app.SparkWrapper import value_counts, rename_columns, create_frame, make_window
+from app.spark_wrapper import value_counts, rename_columns, create_frame, make_window
 
 
 class TestSparkWrapper(TestCase):
