Skip to content

Commit

Permalink
feat: add PredicatesDef definition and validation
Browse files Browse the repository at this point in the history
  • Loading branch information
marcofavorito committed Jun 30, 2023
1 parent f90b899 commit ee69528
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 104 deletions.
8 changes: 5 additions & 3 deletions pddl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pddl.custom_types import namelike, parse_name, to_names, to_types # noqa: F401
from pddl.definitions.base import TypesDef
from pddl.definitions.constants_def import ConstantsDef
from pddl.definitions.predicates_def import _PredicatesDef
from pddl.helpers.base import assert_, check, ensure, ensure_set
from pddl.logic.base import And, Formula, is_literal
from pddl.logic.predicates import DerivedPredicate, Predicate
Expand Down Expand Up @@ -61,7 +62,9 @@ def __init__(
self._requirements = ensure_set(requirements)
self._types = TypesDef(types, self._requirements)
self._constants_def = ConstantsDef(self._requirements, self._types, constants)
self._predicates = ensure_set(predicates)
self._predicates_def = _PredicatesDef(
self._requirements, self._types, predicates
)
self._derived_predicates = ensure_set(derived_predicates)
self._actions = ensure_set(actions)

Expand All @@ -70,7 +73,6 @@ def __init__(
def _check_consistency(self) -> None:
"""Check consistency of a domain instance object."""
checker = TypeChecker(self._types, self.requirements)
checker.check_type(self._predicates)
checker.check_type(self._actions)
_check_types_in_has_terms_objects(self._actions, self._types.all_types) # type: ignore
self._check_types_in_derived_predicates()
Expand Down Expand Up @@ -102,7 +104,7 @@ def constants(self) -> AbstractSet[Constant]:
@property
def predicates(self) -> AbstractSet[Predicate]:
"""Get the predicates."""
return self._predicates
return self._predicates_def.predicates

@property
def derived_predicates(self) -> AbstractSet[DerivedPredicate]:
Expand Down
58 changes: 58 additions & 0 deletions pddl/definitions/predicates_def.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#
# Copyright 2021-2023 WhiteMech
#
# ------------------------------
#
# This file is part of pddl.
#
# Use of this source code is governed by an MIT-style
# license that can be found in the LICENSE file or at
# https://opensource.org/licenses/MIT.
#

"""This module implements the ConstantsDef class to handle the constants of a PDDL domain."""
from typing import AbstractSet, Collection, Dict, Optional

from pddl.custom_types import name as name_type
from pddl.definitions.base import TypesDef, _Definition
from pddl.exceptions import PDDLValidationError
from pddl.helpers.base import ensure_set
from pddl.logic.predicates import Predicate
from pddl.requirements import Requirements
from pddl.validation.terms import TermsValidator


class _PredicatesDef(_Definition):
"""A set of predicates of a PDDL domain."""

def __init__(
self,
requirements: AbstractSet[Requirements],
types: TypesDef,
predicates: Optional[Collection[Predicate]],
) -> None:
"""Initialize the PDDL constants section validator."""
super().__init__(requirements, types)

self._predicates: AbstractSet[Predicate] = ensure_set(predicates)

self._check_consistency()

@property
def predicates(self) -> AbstractSet[Predicate]:
"""Get the predicates."""
return self._predicates

def _check_consistency(self) -> None:
"""Check consistency of the predicates definition."""
seen_predicates_by_name: Dict[name_type, Predicate] = {}
for p in self._predicates:
# check that no two predicates have the same name
if p.name in seen_predicates_by_name:
raise PDDLValidationError(
f"these predicates have the same name: {p}, {seen_predicates_by_name[p.name]}"
)
seen_predicates_by_name[p.name] = p

# check that the terms of the predicate are consistent
TermsValidator(self._requirements, self._types).check_terms(p.terms)
70 changes: 16 additions & 54 deletions pddl/logic/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,95 +12,58 @@

"""This class implements PDDL predicates."""
import functools
from typing import Collection, Dict, Generator, Sequence, Set, Tuple
from typing import Sequence

from pddl.custom_types import name as name_type
from pddl.custom_types import namelike, parse_name
from pddl.exceptions import PDDLValidationError
from pddl.helpers.base import assert_, check
from pddl.helpers.base import assert_
from pddl.helpers.cache_hash import cache_hash
from pddl.logic.base import Atomic, Formula
from pddl.logic.terms import Constant, Term, _print_tag_set
from pddl.logic.terms import Constant, Term
from pddl.parser.symbols import Symbols
from pddl.validation.terms import TermsValidator


class _TermsList:
"""
A class wrapper for validating sequences of terms.
Note that this is only for internal validation of terms, and not specific to any PDDL domain type hierarchy.
"""

def __init__(self, terms_list: Collection[Term]) -> None:
"""Initialize the terms list."""
self._terms = tuple(self.check_no_duplicate_iterator(terms_list))
class _BaseAtomic(Atomic):
"""Base class to share common code among atomic formulas classes."""

def __init__(self, *terms: Term) -> None:
"""Initialize the atomic formula."""
TermsValidator.check_no_duplicates(terms)
self._terms = tuple(terms)
self._is_ground: bool = all(isinstance(v, Constant) for v in self._terms)

@property
def terms(self) -> Tuple[Term, ...]:
"""Get the terms sequence."""
def terms(self) -> Sequence[Term]:
"""Get the terms."""
return self._terms

@property
def is_ground(self) -> bool:
"""Check whether the predicate is ground."""
return self._is_ground

@staticmethod
def check_no_duplicate_iterator(
terms: Collection[Term],
) -> Generator[Term, None, None]:
"""
Iterate over terms and check that there are no duplicates.
In particular, terms with the same name must have the same type tags.
"""
seen: Dict[name_type, Set[name_type]] = {}
for term in terms:
if term.name not in seen:
seen[term.name] = set(term.type_tags)
else:
check(
seen[term.name] == set(term.type_tags),
f"Term {term} occurred twice with different type tags: "
f"previous type tags {_print_tag_set(seen[term.name])}, "
f"new type tags {_print_tag_set(term.type_tags)}",
exception_cls=PDDLValidationError,
)
yield term


@cache_hash
@functools.total_ordering
class Predicate(Atomic):
class Predicate(_BaseAtomic):
"""A class for a Predicate in PDDL."""

def __init__(self, predicate_name: namelike, *terms: Term):
"""Initialize the predicate."""
self._name = parse_name(predicate_name)
self._terms_list = _TermsList(terms)
super().__init__(*terms)

@property
def name(self) -> name_type:
"""Get the name."""
return self._name

@property
def terms(self) -> Sequence[Term]:
"""Get the terms."""
return self._terms_list.terms

@property
def arity(self) -> int:
"""Get the arity of the predicate."""
return len(self.terms)

@property
def is_ground(self) -> bool:
"""Check whether the predicate is ground."""
return self._terms_list.is_ground

# TODO check whether it's a good idea...
# TODO allow also for keyword-based replacement
# TODO allow skip replacement with None arguments.
Expand Down Expand Up @@ -143,7 +106,7 @@ def __lt__(self, other):
return super().__lt__(other)


class EqualTo(Atomic):
class EqualTo(_BaseAtomic):
"""Equality predicate."""

def __init__(self, left: Term, right: Term):
Expand All @@ -153,11 +116,10 @@ def __init__(self, left: Term, right: Term):
:param left: the left term.
:param right: the right term.
"""
super().__init__(left, right)
self._left = left
self._right = right

self._terms_list = _TermsList([self._left, self._right])

@property
def left(self) -> Term:
"""Get the left operand."""
Expand Down
6 changes: 3 additions & 3 deletions pddl/validation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#

"""Base module for validators."""
from typing import AbstractSet, Collection
from typing import AbstractSet, Callable, Collection

from pddl.custom_types import name as name_type
from pddl.definitions.base import TypesDef
Expand Down Expand Up @@ -44,10 +44,10 @@ def _check_typing_requirement(self, type_tags: Collection[name_type]) -> None:
)

def _check_types_are_available(
self, type_tags: Collection[name_type], what: str
self, type_tags: Collection[name_type], what: Callable[[], str]
) -> None:
"""Check that the types are available in the domain."""
if not self._types.all_types.issuperset(type_tags):
raise PDDLValidationError(
f"types {sorted(type_tags)} of {what} are not in available types {self._types.all_types}"
f"types {sorted(type_tags)} of {what()} are not in available types {sorted(self._types.all_types)}"
)
69 changes: 57 additions & 12 deletions pddl/validation/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,72 @@
#

"""Module for validator of terms."""
from typing import AbstractSet, Collection
from functools import partial
from typing import Collection, Dict, Generator, Set, Tuple, Type

from pddl.definitions.base import TypesDef
from pddl.logic.predicates import _TermsList
from pddl.logic.terms import Term
from pddl.requirements import Requirements
from pddl.custom_types import name as name_type
from pddl.exceptions import PDDLValidationError
from pddl.helpers.base import check
from pddl.logic.terms import Term, _print_tag_set
from pddl.validation.base import BaseValidator


class TermsValidator(BaseValidator):
"""Class for validator of terms."""

def __init__(
self, requirements: AbstractSet[Requirements], types: TypesDef
) -> None:
"""Initialize the validator."""
super().__init__(requirements, types)
@classmethod
def check_no_duplicates(cls, terms: Collection[Term]):
"""
Check that there are no duplicates.
This is the non-iterative version of '_check_no_duplicates_iterator'.
"""
# consume the iterator
list(cls._check_no_duplicates_iterator(terms))

@classmethod
def _check_no_duplicates_iterator(
cls, terms: Collection[Term]
) -> Generator[Term, None, None]:
"""
Iterate over terms and check that there are no duplicates.
In particular:
- terms with the same name must be of the same term type (variable or constant);
- terms with the same name must have the same type tags.
"""
_TypeSpec = Tuple[Type[Term], Set[name_type]]
seen: Dict[name_type, _TypeSpec] = {}
for term in terms:
if term.name not in seen:
seen[term.name] = (term.__class__, set(term.type_tags))
else:
expected_type, expected_type_tags = seen[term.name]
actual_type, actual_type_tags = term.__class__, set(term.type_tags)
check(
expected_type == actual_type,
f"Term {term} already occurred with type {expected_type}: got {actual_type}",
exception_cls=PDDLValidationError,
)
check(
expected_type_tags == actual_type_tags,
f"Term {term} occurred twice with different type tags: "
f"previous type tags {_print_tag_set(expected_type_tags)}, "
f"new type tags {_print_tag_set(actual_type_tags)}",
exception_cls=PDDLValidationError,
)
yield term

def check_terms(self, terms: Collection[Term]) -> None:
"""Check the terms."""
terms_iter = _TermsList.check_no_duplicate_iterator(terms)
terms_iter = self._check_no_duplicates_iterator(terms)
for term in terms_iter:
self._check_typing_requirement(term.type_tags)
self._check_types_are_available(term.type_tags, "terms")
self._check_types_are_available(
term.type_tags, partial(self._terms_to_string, terms)
)

@classmethod
def _terms_to_string(cls, terms: Collection[Term]) -> str:
"""Convert terms to string for error messages."""
return "terms ['" + "', '".join(map(str, terms)) + "']"
Loading

0 comments on commit ee69528

Please sign in to comment.