Skip to content

Commit

Permalink
code format black
Browse files Browse the repository at this point in the history
  • Loading branch information
vighnesh-wednesday committed Dec 14, 2023
1 parent 2a671b2 commit ab0bb1c
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions tests/test_connect_databricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import unittest
from unittest.mock import MagicMock, patch
from app.connect_databricks import create_mount, get_param_value


def constructor(self, mount_point):
self.mountPoint = mount_point


MockMount = type("MockMount", (object,), {"__init__": constructor})


class TestMountFunctions(unittest.TestCase):
def setUp(self):
# Mocking dbutils
self.dbutils = MagicMock()

@patch(
"app.connect_databricks.os.environ",
{
"storage_account_name": "mock_storage_account",
"datalake_access_key": "mock_access_key",
},
)
def test_create_mount_success(self):
container_name = "mock_container"
mount_path = "/mnt/mock_mount_point"

# Mocking fs.mounts() to return an empty list
self.dbutils.fs.mounts.return_value = []

# Call the function to test
create_mount(self.dbutils, container_name, mount_path)

# Assertions
self.dbutils.fs.mount.assert_called_once_with(
source=f"wasbs://{container_name}@mock_storage_account.blob.core.windows.net/",
mount_point=mount_path,
extra_configs={
"fs.azure.account.key.mock_storage_account.blob.core.windows.net": "mock_access_key"
},
)
self.dbutils.fs.refreshMounts.assert_not_called()

@patch(
"app.connect_databricks.os.environ",
{
"storage_account_name": "mock_storage_account",
"datalake_access_key": "mock_access_key",
},
)
def test_create_mount_already_mounted(self):
container_name = "mock_container"
mount_path = "/mnt/mock_mount_point"

# Mocking fs.mounts() to return a list with the mount path
mocked_mount = MockMount(mount_path)
self.dbutils.fs.mounts.return_value = [mocked_mount]

# Call the function to test
create_mount(self.dbutils, container_name, mount_path)

# Assertions
self.dbutils.fs.mount.assert_not_called()
self.dbutils.fs.refreshMounts.assert_called_once()

def test_get_param_value_success(self):
param_key = "mock_param_key"
mock_param_value = "mock_param_value"

# Mocking dbutils.widgets.get() to return a value
self.dbutils.widgets.get.return_value = mock_param_value

# Call the function to test
result = get_param_value(self.dbutils, param_key)

# Assertions
self.assertEqual(result, mock_param_value)
self.dbutils.widgets.get.assert_called_once_with(param_key)

def test_get_param_value_failure(self):
param_key = "mock_param_key"

# Mocking dbutils.widgets.get() to return None (indicating failure)
self.dbutils.widgets.get.return_value = None

# Call the function to test
result = get_param_value(self.dbutils, param_key)

# Assertions
self.assertIsNone(result)
self.dbutils.widgets.get.assert_called_once_with(param_key)

0 comments on commit ab0bb1c

Please sign in to comment.