Skip to content

Commit

Permalink
annotation fix: replace the `|` union syntax with typing.Union (`X | Y` annotations require Python >= 3.10)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfouesneau committed Oct 4, 2024
1 parent b9313c4 commit f5322e2
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 66 deletions.
10 changes: 6 additions & 4 deletions src/ezpadova/deprecated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Deprecated functions from the previous version, kept here for backward compatibility."""

from typing import Union
from .parsec import get_isochrones
from .tools import deprecated_replacedby
import pandas as pd
Expand All @@ -8,7 +9,7 @@
@deprecated_replacedby("get_isochrones")
def get_one_isochrone(
age_yr=None, Z=None, logage=None, MH=None, ret_table=True, **kwargs
) -> pd.DataFrame | bytes:
) -> Union[pd.DataFrame, bytes]:
"""
Get one isochrone at a given time and metallicity content.
Parameters:
Expand Down Expand Up @@ -51,7 +52,7 @@ def get_one_isochrone(
@deprecated_replacedby("get_isochrones")
def get_Z_isochrones(
z0, z1, dz, age_yr=None, logage=None, ret_table=True, **kwargs
) -> pd.DataFrame | bytes:
) -> Union[pd.DataFrame, bytes]:
"""
Retrieve isochrones for a given metallicity range and age.
Expand Down Expand Up @@ -82,12 +83,13 @@ def get_Z_isochrones(
else:
query["logage"] = [logage, logage, 0]


return get_isochrones(return_df=ret_table, **query, **kwargs)


@deprecated_replacedby("get_isochrones")
def get_t_isochrones(logt0, logt1, dlogt, Z=None, MH=None, ret_table=True, **kwargs):
def get_t_isochrones(
logt0, logt1, dlogt, Z=None, MH=None, ret_table=True, **kwargs
) -> Union[pd.DataFrame, bytes]:
"""
Retrieve isochrones for a given age range and metallicity.
Expand Down
114 changes: 62 additions & 52 deletions src/ezpadova/parsec.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import re
import zlib
from typing import Tuple
from urllib.request import urlopen
from io import BufferedReader, BytesIO
from typing import Tuple, Union
from urllib.request import urlopen

import requests
import pandas as pd
import requests

from .config import configuration, validate_query_parameter
from .tools import get_file_archive_type
Expand Down Expand Up @@ -38,14 +38,16 @@ def build_query(**kwargs) -> dict:
kw.update(kwargs)

# update some keys to match the website requirements
if not kw['photsys_file'].endswith('.dat'):
kw['photsys_file'] = f"YBC_tab_mag_odfnew/tab_mag_{kw['photsys_file']}.dat"
if not kw['imf_file'].endswith('.dat'):
if not kw["photsys_file"].endswith(".dat"):
kw["photsys_file"] = f"YBC_tab_mag_odfnew/tab_mag_{kw['photsys_file']}.dat"
if not kw["imf_file"].endswith(".dat"):
kw["imf_file"] = f"tab_imf/imf_{kw['imf_file']}.dat"
return kw


def parse_result(data: str | bytes | BufferedReader, comment: str = '#') -> pd.DataFrame:
def parse_result(
data: Union[str, bytes, BufferedReader], comment: str = "#"
) -> pd.DataFrame:
"""
Parses the input data and returns a pandas DataFrame.
Expand All @@ -66,24 +68,28 @@ def parse_result(data: str | bytes | BufferedReader, comment: str = '#') -> pd.D
if isinstance(data, BufferedReader):
data = data.read()

split_txt = data.decode('utf-8').split('\n')
split_txt = data.decode("utf-8").split("\n")
for num, line in enumerate(split_txt):
if line[0] != comment:
break
start = num - 1
header = split_txt[start].replace('#', '').strip().split()
df = pd.read_csv(BytesIO(data), skiprows=start + 1, sep=r'\s+', names=header, comment='#')
df.attrs['comment'] = '\n'.join(k.replace('#', '').strip() for k in split_txt[:start])
header = split_txt[start].replace("#", "").strip().split()
df = pd.read_csv(
BytesIO(data), skiprows=start + 1, sep=r"\s+", names=header, comment="#"
)
df.attrs["comment"] = "\n".join(
k.replace("#", "").strip() for k in split_txt[:start]
)
return df


def query(**kwargs) -> bytes:
"""
Query the CMD webpage with the given parameters.
This function sends a POST request to the CMD webpage specified in the
configuration and retrieves the resulting data. The data is then processed
and returned as bytes. If the server response is incorrect or if there is
This function sends a POST request to the CMD webpage specified in the
configuration and retrieves the resulting data. The data is then processed
and returned as bytes. If the server response is incorrect or if there is
an issue with the data retrieval, a RuntimeError is raised.
Args:
Expand All @@ -93,85 +99,89 @@ def query(**kwargs) -> bytes:
bytes: The retrieved data from the CMD webpage.
Raises:
RuntimeError: If the server response is incorrect or if there is an
RuntimeError: If the server response is incorrect or if there is an
issue with data retrieval.
"""
print(f"Querying {configuration['url']}...")
kw = build_query(**kwargs)
req = requests.post(configuration["url"], data=kw, timeout=60, allow_redirects=True)
if req.status_code != 200:
raise RuntimeError('Server Response is incorrect')
raise RuntimeError("Server Response is incorrect")
else:
print('Retrieving data...')
print("Retrieving data...")

fname = re.compile(r'output\d+').findall(req.text)
domain = '/'.join(configuration['url'].split('/')[:3])
fname = re.compile(r"output\d+").findall(req.text)
domain = "/".join(configuration["url"].split("/")[:3])
if len(fname) > 0:
data_url = f'{domain}/tmp/{fname[0]}.dat'
print(f'Downloading data...{data_url}')
data_url = f"{domain}/tmp/{fname[0]}.dat"
print(f"Downloading data...{data_url}")
bf = urlopen(data_url)
r = bf.read()
typ = get_file_archive_type(r, stream=True)
if typ is not None:
r = zlib.decompress(bytes(r), 15 + 32)
return r
else:
print(configuration['url'], query)
print(configuration["url"], query)
print(req.text)
raise RuntimeError('Server Response is incorrect')
raise RuntimeError("Server Response is incorrect")


def get_isochrones(
age_yr: Tuple[float, float, float] | None = None,
Z: Tuple[float, float, float] | None = None,
logage: Tuple[float, float, float] | None = None,
MH: Tuple[float, float, float] | None = None,
default_ranges: bool = False,
return_df: bool = True,
**kwargs) -> pd.DataFrame | bytes:

age_yr: Union[Tuple[float, float, float], None] = None,
Z: Union[Tuple[float, float, float], None] = None,
logage: Union[Tuple[float, float, float], None] = None,
MH: Union[Tuple[float, float, float], None] = None,
default_ranges: bool = False,
return_df: bool = True,
**kwargs,
) -> Union[pd.DataFrame, bytes]:
kw = configuration["defaults"].copy()
kw.update(kwargs)

# default_ranges means using the form's default values - no parameters are provided
if not default_ranges:
# check parameter consistency
if age_yr is None and logage is None:
raise ValueError('Either age_yr or logage must be provided.')
raise ValueError("Either age_yr or logage must be provided.")
if age_yr is not None and logage is not None:
raise ValueError('Only one of age_yr or logage can be provided.')
raise ValueError("Only one of age_yr or logage can be provided.")

if Z is None and MH is None:
raise ValueError('Either Z or MH must be provided.')
raise ValueError("Either Z or MH must be provided.")
if Z is not None and MH is not None:
raise ValueError('Only one of Z or MH can be provided.')
raise ValueError("Only one of Z or MH can be provided.")

# check that each parameter is either None or a triplet of numbers
for name, param in zip(("age_yr", "Z", "logage", "MH"), (age_yr, Z, logage, MH)):
for name, param in zip(
("age_yr", "Z", "logage", "MH"), (age_yr, Z, logage, MH)
):
if param is not None:
if not isinstance(param, (list, tuple)) or len(param) != 3:
raise ValueError(f'Parameter {name} must be a triplet of Numbers or None. Found {param} instead.')
raise ValueError(
f"Parameter {name} must be a triplet of Numbers or None. Found {param} instead."
)

# setup linear age / log age query
if age_yr is not None:
kw['isoc_isagelog'] = 0
for key, val in zip(('isoc_agelow', 'isoc_ageupp', 'isoc_dage'), age_yr):
kw["isoc_isagelog"] = 0
for key, val in zip(("isoc_agelow", "isoc_ageupp", "isoc_dage"), age_yr):
kw[key] = val
else:
kw['isoc_isagelog'] = 1
for key, val in zip(('isoc_lagelow', 'isoc_lageupp', 'isoc_dlage'), logage):
kw["isoc_isagelog"] = 1
for key, val in zip(("isoc_lagelow", "isoc_lageupp", "isoc_dlage"), logage):
kw[key] = val

# setup metallicity query in Z or [M/H]
if Z is not None:
kw['isoc_ismetlog'] = 0
for key, val in zip(('isoc_zlow', 'isoc_zupp', 'isoc_dz'), Z):
kw["isoc_ismetlog"] = 0
for key, val in zip(("isoc_zlow", "isoc_zupp", "isoc_dz"), Z):
kw[key] = val
else:
kw['isoc_ismetlog'] = 1
for key, val in zip(('isoc_metlow', 'isoc_metupp', 'isoc_dmet'), MH):
kw["isoc_ismetlog"] = 1
for key, val in zip(("isoc_metlow", "isoc_metupp", "isoc_dmet"), MH):
kw[key] = val

# check parameters validity
validate_query_parameter(**kw)

Expand All @@ -183,9 +193,9 @@ def get_isochrones(
return parse_result(res)
else:
return res


if __name__ == "__main__":
df = get_isochrones(default_ranges=True)
print(df)
print(df.attrs['comment'])
print(df.attrs["comment"])
28 changes: 18 additions & 10 deletions src/ezpadova/tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from io import BytesIO
from functools import wraps
import warnings
from functools import wraps
from io import BytesIO
from typing import Union


def deprecated_replacedby(replace_by):
""" This is a decorator which can be used to mark functions as deprecated. """
"""This is a decorator which can be used to mark functions as deprecated."""

def decorator(func):
msg = f"{func.__name__} is deprecated and will be removed in a future version. Use {replace_by} instead."
Expand All @@ -13,16 +14,19 @@ def decorator(func):
def new_func(*args, **kwargs):
warnings.warn(msg, DeprecationWarning)
return func(*args, **kwargs)

# add to docstring
new_func.__doc__ = f"{msg}\n\n{func.__doc__}"

return new_func

return decorator


def get_file_archive_type(filename: str | BytesIO, stream: bool = False) -> str | None:
""" Detect the type of a potentially compressed file.
def get_file_archive_type(
filename: Union[str, BytesIO], stream: bool = False
) -> str | None:
"""Detect the type of a potentially compressed file.
This function checks the beginning of a file to determine if it is compressed
using gzip, bzip2, or zip formats. It returns the corresponding file type if
Expand All @@ -37,11 +41,15 @@ def get_file_archive_type(filename: str | BytesIO, stream: bool = False) -> str
str | None: The type of compression detected ('gz', 'bz2', 'zip'), or None if
no compression is detected.
"""
magic_dict = { b"\x1f\x8b\x08": "gz", b"\x42\x5a\x68": "bz2", b"\x50\x4b\x03\x04": "zip" }
magic_dict = {
b"\x1f\x8b\x08": "gz",
b"\x42\x5a\x68": "bz2",
b"\x50\x4b\x03\x04": "zip",
}

max_len = max(len(x) for x in magic_dict)
if not stream:
with open(filename, 'rb') as f:
with open(filename, "rb") as f:
file_start = f.read(max_len)
else:
try:
Expand All @@ -55,4 +63,4 @@ def get_file_archive_type(filename: str | BytesIO, stream: bool = False) -> str
if file_start.startswith(magic):
return filetype

return None
return None

0 comments on commit f5322e2

Please sign in to comment.