Skip to content

Commit

Permalink
Updated pydantic v1 -> v2.
Browse files Browse the repository at this point in the history
  • Loading branch information
coltonbh committed Aug 31, 2023
1 parent df5a0bd commit 0e3567f
Show file tree
Hide file tree
Showing 13 changed files with 264 additions and 206 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [unreleased]

### Changed

- 🚨 BREAKING CHANGE 🚨 Updated `pydantic` from `v1` -> `v2`.

## [0.4.2]

### Added
Expand Down
242 changes: 154 additions & 88 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ readme = "README.md"

[tool.poetry.dependencies]
python = "^3.8"
pydantic = ">=1.7.4,!=1.8,!=1.8.1,<2.0.0"
pydantic = ">=2.0.0"
numpy = ">=1.20"
toml = "^0.10.2"
pyyaml = "^6.0"
Expand All @@ -26,7 +26,7 @@ pre-commit = "^3.2.1"
pytest-cov = "^4.0.0"
ruff = "^0.0.260"
isort = "^5.12.0"
qcelemental = "^0.25.1"
qcelemental = ">=0.26.0"
types-toml = "^0.10.8.6"
types-pyyaml = "^6.0.12.10"

Expand Down
14 changes: 11 additions & 3 deletions qcio/helper_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
from typing import List, Union

import numpy as np
from pydantic import PlainSerializer
from typing_extensions import Annotated

StrOrPath = Union[str, Path]
StrOrPath = Annotated[Union[str, Path], PlainSerializer(lambda x: str(x))]

# May be energy (float), gradient or hessian (List[List[float]])
SPReturnResult = Union[float, List[List[float]]]

# Type for any values that can be coerced to 2D numpy array
ArrayLike2D = Union[List[List[float]], List[float], np.ndarray]
ArrayLike2D = Annotated[
Union[List[List[float]], List[float], np.ndarray],
PlainSerializer(lambda x: np.array(x).tolist()),
]

# Type for any values that can be coerced to 3D numpy array
ArrayLike3D = Union[List[List[List[float]]], List[List[float]], np.ndarray]
ArrayLike3D = Annotated[
Union[List[List[List[float]]], List[List[float]], np.ndarray],
PlainSerializer(lambda x: np.array(x).tolist()),
]
115 changes: 45 additions & 70 deletions qcio/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import json
from abc import ABC
from base64 import b64decode, b64encode
from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

import numpy as np
import toml
import yaml
from pydantic import BaseModel, Extra, validator
from pydantic import BaseModel, field_serializer, field_validator
from typing_extensions import Self

from ..helper_types import StrOrPath
Expand All @@ -35,49 +34,17 @@ class QCIOModelBase(BaseModel, ABC):
# version: ClassVar[Literal["v1"]] = "v1"
extras: Dict[str, Any] = {}

class Config:
model_config = {
# Raises an error if extra fields are passed to model.
extra = Extra.forbid
"extra": "forbid",
# Allow numpy types in models. Pydantic will no longer raise an exception for
# types it doesn't recognize.
# https://docs.pydantic.dev/latest/usage/types/#arbitrary-types-allowed
arbitrary_types_allowed = True
"arbitrary_types_allowed": True,
# Don't allow mutation of objects
# https://docs.pydantic.dev/latest/usage/models/#faux-immutability
allow_mutation = False
# convert ndarray to list for JSON serialization
# https://docs.pydantic.dev/latest/usage/exporting_models/#json_encodershttps://docs.pydantic.dev/latest/usage/exporting_models/#json_encoders # noqa: E501
json_encoders = {np.ndarray: lambda v: v.tolist()}
# exclude fields with value None from serialization
exclude_none = True

def dict(self, **kwargs):
"""Convert the object to a dictionary.
Properly serialize numpy arrays. Serialization is performed in .dict() so that
multiple string serializers can used it without duplicating logic
(e.g. json, yaml, toml).
"""
model_dict = super().dict(**kwargs)
to_pop = []
for key, value in model_dict.items():
# Custom serialization for numpy arrays, enums, and pathlib Paths
if isinstance(value, np.ndarray):
model_dict[key] = value.tolist()
elif issubclass(type(value), Enum):
model_dict[key] = value.value
elif isinstance(value, Path):
model_dict[key] = str(value)

# Exclude empty lists, dictionaries, and objects with all None values from
# serialization
elif value in [None, [], {}]:
to_pop.append(key)

for key in to_pop:
model_dict.pop(key)

return model_dict
# https://docs.pydantic.dev/2.3/api/config/#pydantic.config.ConfigDict.frozen
"frozen": True,
}

@classmethod
def open(cls, filepath: Union[Path, str]) -> Self:
Expand All @@ -95,37 +62,42 @@ def open(cls, filepath: Union[Path, str]) -> Self:
data = filepath.read_text()

if filepath.suffix in [".yaml", ".yml"]:
return cls.parse_obj(yaml.safe_load(data))
return cls.model_validate(yaml.safe_load(data))
elif filepath.suffix == ".toml":
return cls.parse_obj(toml.loads(data))
return cls.model_validate(toml.loads(data))

# Assume json for all other file extensions
return cls.parse_raw(data)
# pydantic v2
# return cls.model_validate_json(filepath.read_text())
return cls.model_validate_json(data)

def save(self, filepath: Union[Path, str], **kwargs) -> None:
"""Write an object to disk as json.
def save(
self,
filepath: Union[Path, str],
exclude_none=True,
**kwargs,
) -> None:
"""Write an object to disk as json, yaml, or toml.
Args:
filepath: The path to write the object to.I
filepath: The path to write the object to.
exclude_none: If True, attributes with a value of None will not be written.
Changing default behavior from pydantic.model_dump() to True.
"""
filepath = Path(filepath)
filepath.parent.mkdir(exist_ok=True, parents=True)

model_dict = self.model_dump(mode="json", exclude_none=exclude_none, **kwargs)

if filepath.suffix in [".yaml", ".yml"]:
data = yaml.dump(self.dict(**kwargs))
data = yaml.dump(model_dict)

elif filepath.suffix == ".toml":
data = toml.dumps(self.dict(**kwargs))
data = toml.dumps(model_dict)

else:
# Write data to json regardless of file extension
data = self.json(**kwargs)
data = json.dumps(model_dict)

filepath.write_text(data)
# pydantic v2
# filepath.write(self.model_dump())

def __repr_args__(self) -> "ReprArgs":
"""Only show non empty fields in repr."""
Expand All @@ -139,6 +111,15 @@ def exists(value):
(name, value) for name, value in self.__dict__.items() if exists(value)
]

def __eq__(self, other: Any) -> bool:
    """Check equality of two objects by comparing their serialized dumps.

    Necessary because BaseModel.__eq__ does not compare numpy arrays;
    dumping first converts arrays to plain lists that compare cleanly.

    Returns:
        True if ``other`` is the same model class and both dumps are equal;
        False otherwise (including for non-model operands).

    NOTE(review): overriding __eq__ implicitly sets __hash__ to None, so
    instances become unhashable even though model_config sets frozen=True —
    confirm this is intended or add a matching __hash__.
    """
    if isinstance(other, self.__class__):
        return self.model_dump() == other.model_dump()
    return False


class Files(QCIOModelBase):
"""File model for handling string and binary data.
Expand All @@ -151,29 +132,23 @@ class Files(QCIOModelBase):

files: Dict[str, Union[str, bytes]] = {}

@validator("files", pre=True)
@field_validator("files")
def convert_base64_to_bytes(cls, value):
    """Convert 'base64:'-prefixed string values back to raw bytes.

    Binary file contents are serialized as base64 strings prefixed with
    'base64:' (see serialize_files); this validator reverses that encoding
    when a model is loaded.

    NOTE(review): the v1 validator used pre=True but field_validator
    defaults to mode='after'. This appears to work only because str is an
    accepted value type for the field — confirm mode='before' is not
    required.
    """
    for filename, data in value.items():
        if isinstance(data, str) and data.startswith("base64:"):
            # Strip the 7-character 'base64:' prefix before decoding.
            value[filename] = b64decode(data[7:])
    return value

def dict(self, *args, **kwargs):
"""Return a dict representation of the object encoding bytes as b64 strings."""
dict = super().dict(*args, **kwargs)
if self.files: # clause so that empty files dict is not included in dict
files = {}
for filename, data in self.files.items():
if isinstance(data, bytes):
data = f"base64:{b64encode(data).decode('utf-8')}"
files[filename] = data
dict["files"] = files
return dict

def json(self, *args, **kwargs):
"""Return a JSON representation of the object."""
return json.dumps(self.dict(*args, **kwargs))
@field_serializer("files")
def serialize_files(self, files, _info) -> Dict[str, str]:
    """Serialize files, encoding bytes values as 'base64:'-prefixed strings.

    String contents pass through unchanged; binary contents are base64
    encoded so the model can be dumped to JSON/YAML/TOML.
    """
    serialized: Dict[str, str] = {}
    for filename, data in files.items():
        if isinstance(data, bytes):
            serialized[filename] = f"base64:{b64encode(data).decode('utf-8')}"
        else:
            serialized[filename] = data
    return serialized

def add_file(
self, filepath: Union[Path, str], relative_dir: Optional[Path] = None
Expand Down Expand Up @@ -270,7 +245,7 @@ class Provenance(QCIOModelBase):

program: str
program_version: Optional[str] = None
scratch_dir: Optional[Path] = None
scratch_dir: Optional[StrOrPath] = None
wall_time: Optional[float] = None
hostname: Optional[str] = None
hostcpus: Optional[int] = None
Expand Down
7 changes: 6 additions & 1 deletion qcio/models/inputs_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel
from pydantic import BaseModel, field_serializer

from .base_models import Files
from .molecule import Molecule
Expand Down Expand Up @@ -103,3 +103,8 @@ class StructuredInputBase(ProgramArgs):

calctype: CalcType
molecule: Molecule

@field_serializer("calctype")
def serialize_calctype(self, calctype: CalcType, _info) -> str:
    """Serialize CalcType to its plain string value.

    Without this, python-mode dumps would contain the Enum member itself
    rather than a JSON/YAML/TOML-friendly string.
    """
    return calctype.value
19 changes: 14 additions & 5 deletions qcio/models/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
from pydantic import validator
from pydantic import field_serializer, field_validator
from typing_extensions import Self

from qcio.constants import BOHR_TO_ANGSTROM
Expand Down Expand Up @@ -102,12 +102,21 @@ def __repr_args__(self) -> "ReprArgs":
("formula", self.formula),
]

@validator("geometry")
@field_validator("geometry")
def shape_n_by_3(cls, v, info):
    """Coerce geometry to an (n_atoms, 3) numpy array.

    Ensures there is an x, y, and z coordinate for each atom. Relies on
    'symbols' having been validated before 'geometry' (model field order),
    so it is available in info.data.

    Args:
        v: Raw geometry value (flat or nested sequence of floats).
        info: Pydantic ValidationInfo; info.data holds previously
            validated fields, including 'symbols'.

    Returns:
        np.ndarray of shape (n_atoms, 3).
    """
    # Fix: pydantic v2 rejects v1-style validator signatures that take a
    # `values` argument plus **kwargs; (cls, value, info) is the supported
    # v2 form, with previously validated fields available as info.data.
    n_atoms = len(info.data["symbols"])
    return np.array(v).reshape(n_atoms, 3)

@field_serializer("connectivity")
def serialize_connectivity(self, connectivity, _info) -> Optional[List[List[float]]]:
    """Serialize connectivity to a list of lists of floats.

    Cannot have heterogeneous data types in .toml files, so all values
    (atom indices and bond order) are cast to floats.

    Fix: connectivity is Optional and field serializers run for None
    values by default (when_used='always'), so iterating unconditionally
    raised TypeError for molecules without connectivity — pass None
    through instead.

    Returns:
        List of [index, index, bond_order] float triples, or None when
        connectivity is unset.
    """
    if connectivity is None:
        return None
    return [[float(val) for val in bond] for bond in connectivity]

@property
def formula(self) -> str:
"""Return the molecular formula of the molecule using the Hill System.
Expand All @@ -132,9 +141,9 @@ def formula(self) -> str:
for element, count in sorted_elements
)

def dict(self, **kwargs) -> Dict[str, Any]:
def model_dump(self, **kwargs) -> Dict[str, Any]:
"""Handle tuple in connectivity"""
as_dict = super().dict(**kwargs)
as_dict = super().model_dump(**kwargs)
# Connectivity may be empty and super().dict() will remove empty values
if (connectivity := as_dict.get("connectivity")) is not None:
# Must cast all values to floats as toml cannot handle mixed types
Expand Down
15 changes: 8 additions & 7 deletions qcio/models/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import List, Literal, Optional, Union

import numpy as np
from pydantic import validator
from pydantic import field_validator

from qcio.helper_types import ArrayLike2D, ArrayLike3D

Expand Down Expand Up @@ -55,13 +55,14 @@ class Wavefunction(QCIOModelBase):
scf_occupations_a: Optional[ArrayLike2D] = None
scf_occupations_b: Optional[ArrayLike2D] = None

_to_numpy = validator(
@field_validator(
"scf_eigenvalues_a",
"scf_eigenvalues_b",
"scf_occupations_a",
"scf_occupations_b",
allow_reuse=True,
)(lambda x: np.asarray(x) if x is not None else None)
)
def to_numpy(cls, val, _info) -> Optional[np.ndarray]:
return np.asarray(val) if val is not None else None


class SinglePointResults(ResultsBase):
Expand Down Expand Up @@ -112,20 +113,20 @@ class SinglePointResults(ResultsBase):
normal_modes_cartesian: Optional[ArrayLike3D] = None
gibbs_free_energy: Optional[float] = None

@validator("normal_modes_cartesian")
@field_validator("normal_modes_cartesian")
def validate_normal_modes_cartesian_shape(cls, v: ArrayLike3D) -> Optional[np.ndarray]:
    """Coerce normal modes to a (n_modes, n_atoms, 3) numpy array.

    Returns None implicitly when the field is unset (v is None).
    """
    if v is not None:
        # Assume array has length of the number of normal modes
        n_normal_modes = len(v)
        return np.asarray(v).reshape(n_normal_modes, -1, 3)

@validator("gradient")
@field_validator("gradient")
def validate_gradient_shape(cls, v: ArrayLike2D) -> Optional[np.ndarray]:
    """Validate gradient is n x 3.

    Reshapes any accepted input to (n_atoms, 3); returns None implicitly
    when the field is unset (v is None).
    """
    if v is not None:
        return np.asarray(v).reshape(-1, 3)

@validator("hessian")
@field_validator("hessian")
def validate_hessian_shape(cls, v: ArrayLike2D):
"""Validate hessian is square"""
if v is not None:
Expand Down
7 changes: 4 additions & 3 deletions qcio/qcel.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,13 @@ def to_qcel_input(prog_input: ProgramInput) -> Dict[str, Any]:
"connectivity": prog_input.molecule.connectivity or None,
"identifiers": {
key: value
for key, value in prog_input.molecule.identifiers.dict().items()
if key not in ["name_IUPAC", "name_common"] # not on qcel model
for key, value in prog_input.molecule.identifiers.model_dump().items()
if key
not in ["name_IUPAC", "name_common", "extras"] # not on qcel model
},
},
"driver": prog_input.calctype,
"model": prog_input.model.dict(),
"model": prog_input.model.model_dump(),
"keywords": prog_input.keywords,
"extras": prog_input.extras,
}
Expand Down
6 changes: 2 additions & 4 deletions tests/test_base_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,10 @@ def test_file_b64(test_data_dir):
mixin = Files()
mixin.add_file(input_filepath)
assert isinstance(mixin.files[input_filepath.name], bytes)
json_str = mixin.json()
json_str = mixin.model_dump_json()
json_dict = json.loads(json_str)
assert json_dict["files"][input_filepath.name].startswith("base64:")
mixin_new = mixin.parse_raw(json_str)
# v2
# file = File.model_validate_json(json_dict["data"])
mixin_new = mixin.model_validate_json(json_str)
# Round trip of file is lossless
assert mixin_new.files["c0"] == input_filepath.read_bytes()

Expand Down
Loading

0 comments on commit 0e3567f

Please sign in to comment.