Skip to content

Commit

Permalink
[#983] cleaned up spectro/_fit12d_dinput.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dvezinet committed Nov 13, 2024
1 parent 34c13b1 commit 958c486
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 30 deletions.
14 changes: 10 additions & 4 deletions tofu/spectro/_analysis_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
]


_LTYPES = (int, float, np.integer, np.float64)

_GITHUB = 'https://github.com/ToFuProject/tofu/issues'
_WINTIT = f'tofu-{__version__}\treport issues / requests at {_GITHUB}'


# Useful scalar types
_NINT = (np.int32, np.int64)
_INT = (int,) + _NINT
_NFLOAT = (np.float32, np.float64)
_FLOAT = (float,) + _NFLOAT
_NUMB = _INT + _FLOAT


###########################################################
###########################################################
# hidden utilities
Expand Down Expand Up @@ -127,7 +133,7 @@ def _get_localextrema_1d_check(
width = 0.
else:
width = False
c0 = width is False or (isinstance(width, _LTYPES) and width >= 0.)
c0 = width is False or (isinstance(width, _NUMB) and width >= 0.)
if not c0:
msg = (
"Arg width must be a float\n"
Expand All @@ -146,7 +152,7 @@ def _get_localextrema_1d_check(

if rel_height is None:
rel_height = 0.8
if not (isinstance(rel_height, _LTYPES) and 0 <= rel_height <= 1.):
if not (isinstance(rel_height, _NUMB) and 0 <= rel_height <= 1.):
msg = (
"Arg rel_height must be positive float in [0, 1]!\n"
+ "Provided: {}".format(rel_height)
Expand Down
59 changes: 33 additions & 26 deletions tofu/spectro/_fit12d_dinput.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
_SUBSET = False
_VALID_NSIGMA = 6.
_VALID_FRACTION = 0.8
_LTYPES = [int, float, np.int_, np.float64]
_DBOUNDS = {
'bck_amp': (0., 3.),
'bck_rate': (-3., 3.),
Expand Down Expand Up @@ -75,6 +74,14 @@
}


# Useful scalar types
_NINT = (np.int32, np.int64)
_INT = (int,) + _NINT
_NFLOAT = (np.float32, np.float64)
_FLOAT = (float,) + _NFLOAT
_NUMB = _INT + _FLOAT


###########################################################
###########################################################
#
Expand Down Expand Up @@ -193,13 +200,13 @@ def _checkformat_dconstants(dconstants=None, dconstraints=None):
and isinstance(v0, dict)
and all([
k1 in dconstraints[k0].keys()
and type(v1) in _LTYPES
and isinstance(v1, _NUMB)
for k1, v1 in v0.items()
])
)
or (
k0 not in _DORDER
and type(v0) in _LTYPES
and isinstance(v0, _NUMB)
)
)
)
Expand Down Expand Up @@ -240,7 +247,7 @@ def _dconstraints_double(dinput, dconstraints, defconst=_DCONSTRAINTS):
or (
isinstance(dinput['double'], dict)
and all([
kk in ['dratio', 'dshift'] and type(vv) in _LTYPES
kk in ['dratio', 'dshift'] and isinstance(vv, _NUMB)
for kk, vv in dinput['double'].items()
])
)
Expand Down Expand Up @@ -652,14 +659,14 @@ def _checkformat_domain(domain=None, keys=['lamb', 'phi']):
type(v0) in ltypesin + ltypesout
and (
(
all([type(v1) in _LTYPES for v1 in v0])
all([isinstance(v1, _NUMB) for v1 in v0])
and len(v0) == 2
and v0[1] > v0[0]
)
or (
all([
type(v1) in ltypesin + ltypesout
and all([type(v2) in _LTYPES for v2 in v1])
and all([isinstance(v2, _NUMB) for v2 in v1])
and len(v1) == 2
and v1[1] > v1[0]
for v1 in v0
Expand Down Expand Up @@ -781,7 +788,7 @@ def _binning_check(
raise Exception(msg2)

# Check which format was passed and return None or dict
ltypes0 = _LTYPES
ltypes0 = _NUMB
ltypes1 = [tuple, list, np.ndarray]
lc = [
binning is False,
Expand Down Expand Up @@ -1358,7 +1365,7 @@ def _extract_lphi_spectra(
def _checkformat_possubset(pos=None, subset=None):
if pos is None:
pos = _POS
c0 = isinstance(pos, bool) or type(pos) in _LTYPES
c0 = isinstance(pos, bool) or isinstance(pos, _NUMB)
if not c0:
msg = ("Arg pos must be either:\n"
+ "\t- False: no positivity constraints\n"
Expand Down Expand Up @@ -1716,13 +1723,13 @@ def _dvalid_checkfocus(
return False

# Check focus and transform to array of floats
if isinstance(focus, tuple([str] + _LTYPES)):
if isinstance(focus, (str,) + _NUMB):
focus = [focus]

lc = [
isinstance(focus, (list, tuple, np.ndarray))
and all([
(isinstance(ff, tuple(_LTYPES)) and ff > 0.)
(isinstance(ff, _NUMB) and ff > 0.)
or (isinstance(ff, str) and ff in lines_keys)
for ff in focus
]),
Expand Down Expand Up @@ -1755,11 +1762,11 @@ def _dvalid_checkfocus(
focus_half_width = (np.nanmax(lamb) - np.nanmin(lamb))/10.

lc0 = [
type(focus_half_width) in _LTYPES,
isinstance(focus_half_width, _NUMB),
(
type(focus_half_width) in [list, tuple, np.ndarray]
and len(focus_half_width) == focus.size
and all([type(fhw) in _LTYPES for fhw in focus_half_width])
and all([isinstance(fhw, _NUMB) for fhw in focus_half_width])
)
]

Expand Down Expand Up @@ -1974,7 +1981,7 @@ def _checkformat_dlines(dlines=None, domain=None):
and isinstance(v0, dict)
and 'lambda0' in v0.keys()
and (
type(v0['lambda0']) in _LTYPES
isinstance(v0['lambda0'], _NUMB)
or (
isinstance(v0['lambda0'], np.ndarray)
and v0['lambda0'].size == 1
Expand Down Expand Up @@ -2540,16 +2547,16 @@ def _fit12d_checkformat_dscalesx0(
lkfalse = [
k0 for k0, v0 in din.items()
if not (
(k0 in lkconst and type(v0) in _LTYPES)
or (k0 in lk and type(v0) in _LTYPES + [np.ndarray])
(k0 in lkconst and isinstance(v0, _NUMB))
or (k0 in lk and isinstance(v0, _NUMB + (np.ndarray,)))
or (
k0 in lkdict
and type(v0) in _LTYPES + [np.ndarray]
and isinstance(v0, _NUMB + (np.ndarray,))
or (
isinstance(v0, dict)
and all([
k1 in dinput[k0]['keys']
and type(v1) in _LTYPES + [np.ndarray]
and isinstance(v1, _NUMB + (np.ndarray,))
for k1, v1 in v0.items()
])
)
Expand Down Expand Up @@ -2585,15 +2592,15 @@ def _fit12d_filldef_dscalesx0_dict(

# Check vref
if vref is not None:
if type(vref) not in _LTYPES and len(vref) not in [1, nspect]:
if (not isinstance(vref, _NUMB)) and len(vref) not in [1, nspect]:
msg = (
"Non-conform vref for "
+ "{}['{}']\n".format(din_name, key)
+ "\t- expected: float or array (size {})\n".format(nspect)
+ "\t- provided: {}".format(vref)
)
raise Exception(msg)
if type(vref) in _LTYPES:
if isinstance(vref, _NUMB):
vref = np.full((nspect,), vref)
elif len(vref) == 1:
vref = np.full((nspect,), vref[0])
Expand All @@ -2604,10 +2611,10 @@ def _fit12d_filldef_dscalesx0_dict(
din[key] = {k0: vref for k0 in dinput[key]['keys']}

elif not isinstance(din[key], dict):
assert type(din[key]) in _LTYPES + [np.ndarray]
assert isinstance(din[key], _NUMB + (np.ndarray,))
if hasattr(din[key], '__len__') and len(din[key]) == 1:
din[key] = din[key][0]
if type(din[key]) in _LTYPES:
if isinstance(din[key], _NUMB):
din[key] = {
k0: np.full((nspect,), din[key])
for k0 in dinput[key]['keys']
Expand All @@ -2616,15 +2623,15 @@ def _fit12d_filldef_dscalesx0_dict(
din[key] = {k0: din[key] for k0 in dinput[key]['keys']}
else:
msg = (
"{}['{}'] not conform!".format(dd_name, key)
"{}['{}'] not conform!".format(din_name, key)
)
raise Exception(msg)

else:
for k0 in dinput[key]['keys']:
if din[key].get(k0) is None:
din[key][k0] = vref
elif type(din[key][k0]) in _LTYPES:
elif isinstance(din[key][k0], _NUMB):
din[key][k0] = np.full((nspect,), din[key][k0])
elif len(din[key][k0]) == 1:
din[key][k0] = np.full((nspect,), din[key][k0][0])
Expand All @@ -2645,7 +2652,7 @@ def _fit12d_filldef_dscalesx0_float(
nspect=None,
):
if din.get(key) is None:
if type(vref) in _LTYPES:
if isinstance(vref, _NUMB):
din[key] = np.full((nspect,), vref)
elif np.array(vref).shape == (1,):
din[key] = np.full((nspect,), vref[0])
Expand All @@ -2659,7 +2666,7 @@ def _fit12d_filldef_dscalesx0_float(
)
raise Exception(msg)
else:
if type(din[key]) in _LTYPES:
if isinstance(din[key], _NUMB):
din[key] = np.full((nspect,), din[key])
elif din[key].shape == (1,):
din[key] = np.full((nspect,), din[key][0])
Expand Down Expand Up @@ -2821,7 +2828,7 @@ def fit12d_dscales(dscales=None, dinput=None):
)[None, :]
dscales['amp'][key] = np.nanmax(data*conv, axis=1)
else:
if type(dscales['amp'][key]) in _LTYPES:
if isinstance(dscales['amp'][key], _NUMB):
dscales['amp'][key] = np.full((nspect,), dscales['amp'][key])
else:
assert dscales['amp'][key].shape == (nspect,)
Expand Down

0 comments on commit 958c486

Please sign in to comment.