added environment test cases
vighnesh-wednesday committed Dec 16, 2023
1 parent eaa8fa1 commit 7b361e7
Showing 5 changed files with 312 additions and 3 deletions.
app/enviroment.py → app/environment.py (file renamed without changes)
2 changes: 1 addition & 1 deletion main.py
@@ -6,7 +6,7 @@
from pyspark.sql import Window
import pyspark.sql.functions as F

-import app.enviroment as env
+import app.environment as env
import app.spark_wrapper as sw

load_dotenv("app/.custom-env")
263 changes: 263 additions & 0 deletions tests/test_environment.py
@@ -0,0 +1,263 @@
import subprocess
import unittest
from unittest.mock import patch, MagicMock
from app.environment import (
    set_keys_get_spark,
    get_dataframes,
    get_read_path,
    get_write_path,
    get_data,
)


class TestSetKeysGetSpark(unittest.TestCase):
    @patch("app.environment.cg.init_glue")
    @patch("app.environment.cd.create_mount")
    @patch("app.environment.dotenv.load_dotenv")
    def test_databricks_environment(
        self, mock_load_dotenv, mock_create_mount, mock_init_glue
    ):
        # Mocking dbutils and spark
        dbutils = MagicMock()
        spark = MagicMock()

        # Mocking the widgets
        dbutils.widgets.get = MagicMock(
            side_effect=lambda key: {
                "kaggle_username": "mock_username",
                "kaggle_token": "mock_token",
                "storage_account_name": "mock_account_name",
                "datalake_access_key": "mock_access_key",
            }[key]
        )

        # Call the function
        result_spark = set_keys_get_spark(True, dbutils, spark)

        # Assert
        self.assertEqual(result_spark, spark)
        dbutils.widgets.get.assert_called_with("datalake_access_key")
        mock_create_mount.assert_called_with(
            dbutils, "transformed", "/mnt/transformed/"
        )
        mock_init_glue.assert_not_called()
        mock_load_dotenv.assert_not_called()

@patch("app.environment.cg.init_glue")
@patch("app.environment.cd.create_mount")
@patch("app.environment.dotenv.load_dotenv")
def test_glue_local_environment(
self, mock_load_dotenv, mock_create_mount, mock_init_glue
):
# Mocking dbutils and spark
dbutils = MagicMock()
spark = MagicMock()

mock_spark, mock_args = MagicMock(), {"JOB_NAME": "local"}
mock_init_glue.return_value = (mock_spark, mock_args)

# Call the function
result_spark = set_keys_get_spark(False, dbutils, spark)

# Assert
self.assertEqual(result_spark, mock_spark)
dbutils.widgets.get.assert_not_called()
mock_create_mount.assert_not_called()
mock_init_glue.assert_called_once()
mock_load_dotenv.assert_called_once()

@patch("app.environment.cg.init_glue")
@patch("app.environment.cd.create_mount")
@patch("app.environment.dotenv.load_dotenv")
def test_glue_online_environment(
self, mock_load_dotenv, mock_create_mount, mock_init_glue
):
# Mocking dbutils and spark
dbutils = MagicMock()
spark = MagicMock()

mock_spark, mock_args = MagicMock(), {
"JOB_NAME": "online",
"KAGGLE_USERNAME": "mock_name",
"KAGGLE_KEY": "mock_key",
}
mock_init_glue.return_value = (mock_spark, mock_args)

# mock_args['JOB_NAME'] = "local"

# Call the function
result_spark = set_keys_get_spark(False, dbutils, spark)

# Assert
self.assertEqual(result_spark, mock_spark)
dbutils.widgets.get.assert_not_called()
mock_create_mount.assert_not_called()
mock_init_glue.assert_called_once()
mock_load_dotenv.assert_not_called()

@patch("app.environment.sw.create_frame")
@patch("os.listdir")
@patch("subprocess.run")
def test_databricks_dataframes(self, mock_run, mock_listdir, mock_create_frame):
# Mocking spark
spark = MagicMock()

# Mocking directory_path
directory_path = "/mnt/rawdata"

# Mocking csv_files
mock_listdir.return_value = ["file1.csv", "file2.csv", "file3.parquet"]

# Mock create_frame function
mock_create_frame.return_value = MagicMock()

# Call the function
result_df_list = get_dataframes(True, spark, directory_path)

# Assertions
self.assertEqual(len(result_df_list), 2)
mock_listdir.assert_called_with(directory_path)
mock_create_frame.assert_any_call(spark, "/mnt/rawdata/file1.csv")
mock_create_frame.assert_any_call(spark, "/mnt/rawdata/file2.csv")
mock_run.assert_not_called()

@patch("app.environment.sw.create_frame")
@patch("os.listdir")
@patch("subprocess.run")
def test_glue_dataframes(self, mock_run, mock_listdir, mock_create_frame):
# Mocking spark
spark = MagicMock()

# Mocking directory_path
directory_path = "/local/path"

# Mocking subprocess result
mock_run.return_value.stdout = "file1.csv\nfile2.csv\nfile3.json"

# Mock create_frame function
mock_create_frame.return_value = MagicMock()

# Call the function
result_df_list = get_dataframes(False, spark, directory_path)

# Assertions
self.assertEqual(len(result_df_list), 2)
mock_run.assert_called_with(
f"aws s3 ls {directory_path}",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
shell=True,
check=True,
)
mock_create_frame.assert_any_call(spark, "/local/path/file1.csv")
mock_create_frame.assert_any_call(spark, "/local/path/file2.csv")
mock_listdir.assert_not_called()

@patch("app.environment.os.getenv")
def test_read_path_databricks(self, mock_os_getenv):
mock_os_getenv.return_value = "path/to/databricks_read"

result = get_read_path(True)

self.assertEqual(result, "path/to/databricks_read")
mock_os_getenv.assert_called_once_with("DATABRICKS_READ_PATH")

@patch("app.environment.os.getenv")
def test_write_path_databricks(self, mock_os_getenv):
mock_os_getenv.return_value = "path/to/databricks_write"

result = get_write_path(True)

self.assertEqual(result, "path/to/databricks_write")
mock_os_getenv.assert_called_once_with("DATABRICKS_WRITE_PATH")

@patch("app.environment.os.getenv")
def test_read_path_glue(self, mock_os_getenv):
mock_os_getenv.return_value = "path/to/glue_read"

result = get_read_path(False)

self.assertEqual(result, "path/to/glue_read")
mock_os_getenv.assert_called_once_with("GLUE_READ_PATH")

@patch("app.environment.os.getenv")
def test_write_path_glue(self, mock_os_getenv):
mock_os_getenv.return_value = "path/to/glue_write"

result = get_write_path(False)

self.assertEqual(result, "path/to/glue_write")
mock_os_getenv.assert_called_once_with("GLUE_WRITE_PATH")

@patch("app.environment.set_keys_get_spark")
@patch("app.environment.get_read_path")
@patch("app.environment.get_dataframes")
@patch("app.extraction.extract_from_kaggle")
def test_kaggle_extraction_enabled(
self,
mock_extract_from_kaggle,
mock_get_dataframes,
mock_get_read_path,
mock_set_keys_get_spark,
):
# Mocking parameters
databricks = True
kaggle_extraction = True
dbutils = MagicMock()
spark = MagicMock()

# Mocking set_keys_get_spark function
mock_set_keys_get_spark.return_value = spark

# Mocking get_read_path function
mock_get_read_path.return_value = "/mnt/rawdata"

# Mocking extract_from_kaggle function
mock_get_dataframes.return_value = [MagicMock(), MagicMock()]

# Call the function
result_data = get_data(databricks, kaggle_extraction, dbutils, spark)

# Assertions
self.assertEqual(len(result_data), 2)
mock_set_keys_get_spark.assert_called_once_with(databricks, dbutils, spark)
mock_get_read_path.assert_called_once_with(databricks)
mock_extract_from_kaggle.assert_called_once_with(databricks, "/mnt/rawdata")
mock_get_dataframes.assert_called_once_with(databricks, spark, "/mnt/rawdata")

@patch("app.environment.set_keys_get_spark")
@patch("app.environment.get_read_path")
@patch("app.environment.get_dataframes")
@patch("app.extraction.extract_from_kaggle")
def test_kaggle_extraction_disabled(
self,
mock_extract_from_kaggle,
mock_get_dataframes,
mock_get_read_path,
mock_set_keys_get_spark,
):
# Mocking parameters
databricks = False
kaggle_extraction = False
dbutils = MagicMock()
spark = MagicMock()

# Mocking set_keys_get_spark function
mock_set_keys_get_spark.return_value = spark

# Mocking get_read_path function
mock_get_read_path.return_value = "/local/path"

# Mocking extract_from_kaggle function
mock_get_dataframes.return_value = [MagicMock(), MagicMock()]

# Call the function
result_data = get_data(databricks, kaggle_extraction, dbutils, spark)

# Assertions
self.assertEqual(len(result_data), 2)
mock_set_keys_get_spark.assert_called_once_with(databricks, dbutils, spark)
mock_get_read_path.assert_called_once_with(databricks)
mock_extract_from_kaggle.assert_not_called()
mock_get_dataframes.assert_called_once_with(databricks, spark, "/local/path")
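
Note: the contract these tests pin down implies an app/environment.py roughly like the sketch below. This is a hypothetical reconstruction inferred from the mocks and assertions above, not code from this commit — in particular the module names behind the cg/cd aliases, the Spark config call, and the Kaggle environment-variable handling are assumptions.

# Hypothetical sketch of app/environment.py, reverse-engineered from the tests.
import os
import subprocess

import dotenv

import app.connect_glue as cg        # assumed name behind the "cg" alias
import app.connect_databricks as cd  # assumed name behind the "cd" alias
import app.spark_wrapper as sw
from app import extraction


def set_keys_get_spark(databricks, dbutils, spark):
    if databricks:
        # Databricks: read secrets from job widgets. The test asserts that
        # "datalake_access_key" is the last widget read and that the last
        # mount created is ("transformed", "/mnt/transformed/").
        os.environ["KAGGLE_USERNAME"] = dbutils.widgets.get("kaggle_username")
        os.environ["KAGGLE_KEY"] = dbutils.widgets.get("kaggle_token")
        account = dbutils.widgets.get("storage_account_name")
        key = dbutils.widgets.get("datalake_access_key")
        # Assumed: wire the access key into the Spark session.
        spark.conf.set(f"fs.azure.account.key.{account}.dfs.core.windows.net", key)
        cd.create_mount(dbutils, "rawdata", "/mnt/rawdata/")  # assumed earlier mount
        cd.create_mount(dbutils, "transformed", "/mnt/transformed/")
        return spark
    # Glue: init_glue() supplies the session and the resolved job arguments.
    glue_spark, args = cg.init_glue()
    if args["JOB_NAME"] == "local":
        dotenv.load_dotenv("app/.custom-env")  # path assumed from main.py
    else:
        os.environ["KAGGLE_USERNAME"] = args["KAGGLE_USERNAME"]
        os.environ["KAGGLE_KEY"] = args["KAGGLE_KEY"]
    return glue_spark


def get_dataframes(databricks, spark, directory_path):
    if databricks:
        files = os.listdir(directory_path)
    else:
        # The test feeds one bare file name per stdout line; real `aws s3 ls`
        # output also carries date/size columns that would need stripping.
        result = subprocess.run(
            f"aws s3 ls {directory_path}",
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            shell=True,
            check=True,
        )
        files = result.stdout.splitlines()
    return [
        sw.create_frame(spark, f"{directory_path}/{name}")
        for name in files
        if name.endswith(".csv")
    ]


def get_read_path(databricks):
    return os.getenv("DATABRICKS_READ_PATH" if databricks else "GLUE_READ_PATH")


def get_write_path(databricks):
    return os.getenv("DATABRICKS_WRITE_PATH" if databricks else "GLUE_WRITE_PATH")


def get_data(databricks, kaggle_extraction, dbutils, spark):
    spark = set_keys_get_spark(databricks, dbutils, spark)
    read_path = get_read_path(databricks)
    if kaggle_extraction:
        extraction.extract_from_kaggle(databricks, read_path)
    return get_dataframes(databricks, spark, read_path)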
34 changes: 33 additions & 1 deletion tests/test_spark_wrapper.py
@@ -1,7 +1,13 @@
from unittest import TestCase
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
-from app.spark_wrapper import value_counts, rename_columns, create_frame, make_window
+from app.spark_wrapper import (
+    value_counts,
+    rename_columns,
+    create_frame,
+    make_window,
+    rename_same_columns,
+)


class TestSparkWrapper(TestCase):
@@ -103,3 +109,29 @@ def test_make_window(self):
        for actual, expected in zip(actual_data, expected_data):
            for col_name in expected.keys():
                self.assertEqual(actual[col_name], expected[col_name])

    def test_rename_same_columns(self):
        df = self.df
        df = df.withColumn("ADDRESS_LINE1", F.lit("123 Main St"))
        df = df.withColumn("ADDRESS_LINE2", F.lit("Apt 456"))
        df = df.withColumn("CITY", F.lit("Cityville"))
        df = df.withColumn("STATE", F.lit("CA"))
        df = df.withColumn("POSTAL_CODE", F.lit("12345"))

        df = rename_same_columns(df, "CUSTOMER")

        actual_columns = df.columns

        expected_columns = [
            "stock_name",
            "market",
            "close_price",
            "date",
            "CUSTOMER_ADDRESS_LINE1",
            "CUSTOMER_ADDRESS_LINE2",
            "CUSTOMER_CITY",
            "CUSTOMER_STATE",
            "CUSTOMER_POSTAL_CODE",
        ]

        self.assertListEqual(actual_columns, expected_columns)
16 changes: 15 additions & 1 deletion tests/test_spark_wrapper_failure.py
@@ -3,7 +3,13 @@
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import utils as U
-from app.spark_wrapper import value_counts, rename_columns, create_frame, make_window
+from app.spark_wrapper import (
+    value_counts,
+    rename_columns,
+    create_frame,
+    make_window,
+    rename_same_columns,
+)


class TestSparkWrapper(TestCase):
@@ -71,3 +77,11 @@ def test_rename_column_invalid_datatype(self):
        expected_error_message = "WRONG DATATYPE"
        actual_error_message = str(context.exception)
        self.assertTrue(expected_error_message in actual_error_message)

    def test_rename_same_column_failure(self):
        with self.assertRaises(ValueError) as context:
            rename_same_columns(self.df, "VENDOR")

        expected_error_message = "COLUMN DOESN'T EXIST"
        actual_error_message = str(context.exception)
        self.assertTrue(expected_error_message in actual_error_message)
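
Taken together, test_rename_same_columns in tests/test_spark_wrapper.py and this failure case imply a rename_same_columns along these lines — a minimal sketch; the column list and error text come straight from the assertions, the rest is assumption.

# Hypothetical sketch of app.spark_wrapper.rename_same_columns, inferred from
# the two tests: prefix a fixed set of shared address columns with "<prefix>_",
# failing loudly when an expected column is absent.
ADDRESS_COLUMNS = ["ADDRESS_LINE1", "ADDRESS_LINE2", "CITY", "STATE", "POSTAL_CODE"]


def rename_same_columns(df, prefix):
    for col_name in ADDRESS_COLUMNS:
        if col_name not in df.columns:
            raise ValueError("COLUMN DOESN'T EXIST")
        df = df.withColumnRenamed(col_name, f"{prefix}_{col_name}")
    return df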
