Skip to content

Commit

Permalink
Fix bugs with dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
AngelFP committed Sep 27, 2023
1 parent a370276 commit d96813a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
16 changes: 15 additions & 1 deletion optimas/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,18 @@
This module defines a base class for all classes that have a name attribute.
Examples of these are the different optimization parameters and tasks.
"""
import json

from pydantic import BaseModel
from pydantic import BaseModel, Extra
import numpy as np


def json_dumps_dtype(v, *, default):
"""Add support for dumping numpy dtype to json."""
for key, value in v.items():
if key == 'dtype':
v[key] = np.dtype(value).descr
return json.dumps(v)


class NamedBase(BaseModel):
Expand All @@ -22,3 +32,7 @@ def __init__(
**kwargs
) -> None:
super().__init__(name=name, **kwargs)

class Config:
extra = Extra.ignore
json_dumps = json_dumps_dtype
14 changes: 13 additions & 1 deletion optimas/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from typing import Optional, Any

from pydantic import validator
import numpy as np

from .base import NamedBase


Expand All @@ -16,7 +19,7 @@ class Parameter(NamedBase):
The data type of the parameter. Any object that can be converted to a
numpy dtype.
"""
dtype: Optional[Any] = float
dtype: Optional[Any]

def __init__(
self,
Expand All @@ -26,6 +29,15 @@ def __init__(
) -> None:
super().__init__(name=name, dtype=dtype, **kwargs)

@validator("dtype", pre=True)
def check_valid_out(cls, v):
try:
_ = np.dtype(v)
except TypeError:
raise ValueError(f"Unable to coerce '{v}' into a NumPy dtype.")
else:
return v


class VaryingParameter(Parameter):
"""Defines an input parameter to be varied during optimization.
Expand Down

0 comments on commit d96813a

Please sign in to comment.