Skip to content

Commit

Permalink
adding more error handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
Masha Iureva authored and Masha Iureva committed Oct 21, 2024
1 parent e59c64b commit 07c2367
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 34 deletions.
3 changes: 3 additions & 0 deletions scripts/validate_mixer/config_handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import yaml
import json
import os
from typing import Dict, Any, List, Union, Type
from env_handler import expand_env_vars_in_config

def load_config(config_path: str) -> Dict[str, Any]:
"""Load the configuration file (YAML or JSON)."""
if not os.path.exists(config_path):
raise FileNotFoundError(f"Config file not found at path: {config_path}")
try:
with open(config_path, 'r') as file:
if config_path.endswith('.yaml') or config_path.endswith('.yml'):
Expand Down
64 changes: 40 additions & 24 deletions scripts/validate_mixer/file_operations.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import os
import random
import re
import json
import itertools
from typing import Optional, List, Dict, Any, Tuple
from tqdm import tqdm
import boto3
import json
import itertools
import smart_open
from botocore.exceptions import ClientError

from s3_utils import s3_client, list_s3_objects, get_base_path, get_corresponding_attribute_path
from utils import vprint

class FileDownloadError(Exception):
pass

def sample_files(s3_path: str, num_samples: int) -> List[str]:
"""Sample a subset of files from an S3 path."""
all_files = list(list_s3_objects(s3_path))
Expand All @@ -21,7 +26,14 @@ def sample_files(s3_path: str, num_samples: int) -> List[str]:

def download_file(s3_path: str, local_path: str) -> None:
bucket, key = s3_path.replace("s3://", "").split("/", 1)
s3_client.download_file(bucket, key, local_path)
try:
s3_client.download_file(bucket, key, local_path)
except ClientError as e:
if e.response['Error']['Code'] == '404':
raise FileDownloadError(f"File not found: {s3_path}")
else:
raise FileDownloadError(f"Error downloading file {s3_path}: {str(e)}")


def sample_and_download_files(stream: Dict[str, Any], num_samples: int) -> Tuple[List[str], Dict[str, List[str]]]:
temp_dir = "temp_sample_files"
Expand All @@ -43,31 +55,35 @@ def sample_and_download_files(stream: Dict[str, Any], num_samples: int) -> Tuple
local_attr_samples_dict = {attr_type: [] for attr_type in stream['attributes']}

for doc_sample in doc_samples:
local_doc_path = os.path.join(temp_dir, os.path.basename(doc_sample))
download_file(doc_sample, local_doc_path)
local_doc_samples.append(local_doc_path)
pbar.update(1)

# Extract the base name and extension
base_name, extension = os.path.splitext(os.path.basename(doc_sample))
if extension == '.gz':
# Handle double extensions like .jsonl.gz
base_name, inner_extension = os.path.splitext(base_name)
extension = inner_extension + extension

for attr_type in stream['attributes']:
attr_sample = get_corresponding_attribute_path(doc_sample, base_doc_path, base_attr_path, attr_type)
# Construct the new filename with the attribute type before the extension, using a hyphen
new_filename = f"{base_name}-{attr_type}{extension}"
local_attr_path = os.path.join(temp_dir, new_filename)
download_file(attr_sample, local_attr_path)
local_attr_samples_dict[attr_type].append(local_attr_path)
try:
local_doc_path = os.path.join(temp_dir, os.path.basename(doc_sample))
download_file(doc_sample, local_doc_path)
local_doc_samples.append(local_doc_path)
pbar.update(1)


# Extract the base name and extension
base_name, extension = os.path.splitext(os.path.basename(doc_sample))
if extension == '.gz':
# Handle double extensions like .jsonl.gz
base_name, inner_extension = os.path.splitext(base_name)
extension = inner_extension + extension

for attr_type in stream['attributes']:
attr_sample = get_corresponding_attribute_path(doc_sample, base_doc_path, base_attr_path, attr_type)
# Construct the new filename with the attribute type before the extension, using a hyphen
new_filename = f"{base_name}-{attr_type}{extension}"
local_attr_path = os.path.join(temp_dir, new_filename)
download_file(attr_sample, local_attr_path)
local_attr_samples_dict[attr_type].append(local_attr_path)
pbar.update(1)
except FileDownloadError as e:
print(f"Warning: {str(e)}. Skipping this file and its attributes.")
continue

return local_doc_samples, local_attr_samples_dict

except Exception as e:
print(f"An error occurred: {str(e)}")
print(f"An error occurred during file sampling and downloading: {str(e)}")
raise

def count_file_lines(file_path: str) -> int:
Expand Down
4 changes: 2 additions & 2 deletions scripts/validate_mixer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def main(config_path, num_samples, verbose):

if not validate_s3_paths_and_permissions(config):
print("S3 path validation FAILED")
return
# return

if not validate_stream_filters(config):
print("Filter validation FAILED.\n")
Expand All @@ -31,7 +31,7 @@ def main(config_path, num_samples, verbose):
print("Document and attribute validation FAILED")
return

print("Validation SUCCEEDED!")
print("Validation FINISHED!")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Validate mixer configuration")
Expand Down
46 changes: 38 additions & 8 deletions scripts/validate_mixer/validator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import re
import os
import shutil
import sys
from typing import Dict, List, Tuple, Any, Optional
from dotenv import load_dotenv

Expand Down Expand Up @@ -34,7 +35,16 @@ def load_and_validate_config(config_path):
load_env_variables()

vprint("Validating configuration file...")
config = load_config(config_path)
try:
config = load_config(config_path)

except FileNotFoundError as e:
print(str(e))
print("Please check the file path and try again.")
sys.exit(1)
except ValueError as e:
print(f"Error loading or validating config: {str(e)}")
sys.exit(1)

vprint("Validating configuration structure...")
errors = validate_config_structure(config)
Expand Down Expand Up @@ -121,8 +131,7 @@ def validate_stream_filters(config: Dict[str, Any]) -> bool:

return all_valid

# def validate_documents_and_attributes(config: Dict[str, Any], num_samples: int) -> bool:
def validate_documents_and_attributes(config, num_samples):
def validate_documents_and_attributes(config: Dict[str, Any], num_samples: int) -> bool:
vprint("Sampling files...")
temp_dir = "temp_sample_files"
try:
Expand All @@ -135,16 +144,24 @@ def validate_documents_and_attributes(config, num_samples):

base_doc_path = get_base_path(stream['documents'][0])
base_attr_path = re.sub(r'/documents($|/)', r'/attributes\1', base_doc_path)

doc_samples, attr_samples_dict = sample_and_download_files(stream, num_samples)

try:
doc_samples, attr_samples_dict = sample_and_download_files(stream, num_samples)
except Exception as e:
print(f"Error during file sampling and downloading: {str(e)}")
return False

if not doc_samples:
print("No document samples were successfully downloaded. Skipping further validation for this stream.")
continue

for doc_sample in doc_samples:
vprint(f"\nValidating file: {doc_sample}")

doc_line_count = count_file_lines(doc_sample)
if doc_line_count == -1:
print("Failed to count lines in document file. Check the file and try again.")
return False
print(f"Failed to count lines in document file {doc_sample}. Skipping the file")
continue

vprint(f"Document has {doc_line_count} lines")

Expand All @@ -157,7 +174,20 @@ def validate_documents_and_attributes(config, num_samples):
return False

for attr_type in stream['attributes']:
attr_sample = attr_samples_dict[attr_type][doc_samples.index(doc_sample)]
if attr_type not in attr_samples_dict or not attr_samples_dict[attr_type]:
print(f"Warning: No attribute samples found for {attr_type}. Skipping validation for this attribute type.")
continue

try:
doc_index = doc_samples.index(doc_sample)
if doc_index >= len(attr_samples_dict[attr_type]):
print(f"Warning: No corresponding attribute file for document {doc_sample} and attribute type {attr_type}. Skipping validation for this attribute.")
continue
attr_sample = attr_samples_dict[attr_type][doc_index]
except ValueError:
print(f"Warning: Document {doc_sample} not found in samples. Skipping validation for this document.")
continue

vprint(f"\nValidating attribute file: {attr_sample}")

attr_line_count = count_file_lines(attr_sample)
Expand Down

0 comments on commit 07c2367

Please sign in to comment.