Skip to content

Commit

Permalink
feat: check predicates def only has variables
Browse files Browse the repository at this point in the history
  • Loading branch information
marcofavorito committed Jul 1, 2023
1 parent 8ef2a4b commit 8743f7b
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 44 deletions.
10 changes: 5 additions & 5 deletions pddl/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import functools
from collections.abc import Iterable
from typing import AbstractSet, Collection, Optional, Set, Tuple
from typing import AbstractSet, Collection, Optional, Tuple

from pddl.action import Action
from pddl.custom_types import name as name_type
Expand All @@ -40,7 +40,7 @@ def validate(condition: bool, message: str = "") -> None:


def _find_inconsistencies_in_typed_terms(
terms: Optional[Collection[Term]], all_types: Set[name_type]
terms: Optional[Collection[Term]], all_types: AbstractSet[name_type]
) -> Optional[Tuple[Term, name_type]]:
"""
Check that the terms in input all have legal types according to the list of available types.
Expand All @@ -60,7 +60,7 @@ def _find_inconsistencies_in_typed_terms(

def _check_types_in_has_terms_objects(
has_terms_objects: Optional[Collection[Predicate]],
all_types: Set[name_type],
all_types: AbstractSet[name_type],
) -> None:
"""Check that the terms in the set of predicates all have legal types."""
if has_terms_objects is None:
Expand All @@ -72,7 +72,7 @@ def _check_types_in_has_terms_objects(
term, type_tag = check_result
raise PDDLValidationError(
f"type {repr(type_tag)} of term {repr(term)} in atomic expression "
f"{repr(has_terms)} is not in available types {all_types}"
f"{repr(has_terms)} is not in available types {sorted(all_types)}"
)


Expand Down Expand Up @@ -104,7 +104,7 @@ def _check_types_are_available(
"""Check that the types are available in the domain."""
if not self._types.all_types.issuperset(type_tags):
raise PDDLValidationError(
f"types {sorted(type_tags)} of {what} are not in available types {self._types.all_types}"
f"types {sorted(type_tags)} of {what} are not in available types {sorted(self._types.all_types)}"
)

@functools.singledispatchmethod # type: ignore
Expand Down
4 changes: 2 additions & 2 deletions pddl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
It contains the class definitions to build and modify PDDL domains or problems.
"""
from typing import AbstractSet, Collection, Dict, Optional, Tuple, cast
from typing import AbstractSet, Collection, Dict, Mapping, Optional, Tuple, cast

from pddl._validation import TypeChecker, _check_types_in_has_terms_objects, validate
from pddl.action import Action
Expand Down Expand Up @@ -117,7 +117,7 @@ def actions(self) -> AbstractSet["Action"]:
return self._actions

@property
def types(self) -> Dict[name_type, Optional[name_type]]:
def types(self) -> Mapping[name_type, Optional[name_type]]:
"""Get the type definitions, if defined. Else, raise error."""
return self._types.raw

Expand Down
27 changes: 17 additions & 10 deletions pddl/definitions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#

"""Base module for the PDDL definitions."""
from typing import AbstractSet, Dict, Optional, Set, cast
from typing import AbstractSet, Dict, FrozenSet, Mapping, Optional, Set, cast

from pddl.custom_types import name as name_type
from pddl.custom_types import namelike, to_names, to_types # noqa: F401
Expand Down Expand Up @@ -45,13 +45,16 @@ def __init__(
self._all_types = self._get_all_types()
self._types_closure = self._compute_types_closure()

# only for printing purposes
self._sorted_all_types = sorted(self._all_types)

@property
def raw(self) -> Dict[name_type, Optional[name_type]]:
def raw(self) -> Mapping[name_type, Optional[name_type]]:
"""Get the raw types dictionary."""
return self._types

@property
def all_types(self) -> Set[name_type]:
def all_types(self) -> FrozenSet[name_type]:
"""Get all available types."""
return self._all_types

Expand All @@ -60,28 +63,32 @@ def is_subtype(self, type_a: name_type, type_b: name_type) -> bool:
# check whether type_a and type_b are legal types
error_msg = "type {0} is not in available types {1}"
if type_a not in self._all_types:
raise PDDLValidationError(error_msg.format(repr(type_a), self._all_types))
raise PDDLValidationError(
error_msg.format(repr(type_a), self._sorted_all_types)
)
if type_b not in self._all_types:
raise PDDLValidationError(error_msg.format(repr(type_b), self._all_types))
raise PDDLValidationError(
error_msg.format(repr(type_b), self._sorted_all_types)
)

return type_a in self._types_closure.get(type_b, set())

def _get_all_types(self) -> Set[name_type]:
def _get_all_types(self) -> FrozenSet[name_type]:
"""Get all types supported by the domain."""
if self._types is None:
return set()
return frozenset()
result = set(self._types.keys()) | set(self._types.values())
result.discard(None)
return cast(Set[name_type], result)
return cast(FrozenSet[name_type], frozenset(result))

def _compute_types_closure(self) -> Dict[name_type, Set[name_type]]:
def _compute_types_closure(self) -> Mapping[name_type, Set[name_type]]:
"""Compute the closure of the types dictionary."""
return transitive_closure(self._types)

@classmethod
def _check_types_dictionary(
cls,
type_dict: Dict[name_type, Optional[name_type]],
type_dict: Mapping[name_type, Optional[name_type]],
requirements: AbstractSet[Requirements],
) -> None:
"""
Expand Down
7 changes: 5 additions & 2 deletions pddl/definitions/predicates_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pddl.definitions.base import TypesDef, _Definition
from pddl.exceptions import PDDLValidationError
from pddl.helpers.base import ensure_set
from pddl.logic import Variable
from pddl.logic.predicates import Predicate
from pddl.requirements import Requirements
from pddl.validation.terms import TermsValidator
Expand Down Expand Up @@ -54,5 +55,7 @@ def _check_consistency(self) -> None:
)
seen_predicates_by_name[p.name] = p

# check that the terms of the predicate are consistent
TermsValidator(self._requirements, self._types).check_terms(p.terms)
# check that the terms are consistent wrt types, and that are all variables
TermsValidator(
self._requirements, self._types, must_be_instances_of=Variable
).check_terms(p.terms)
4 changes: 2 additions & 2 deletions pddl/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

"""Formatting utilities for PDDL domains and problems."""
from textwrap import indent
from typing import Callable, Collection, Dict, List, Optional
from typing import Callable, Collection, Dict, List, Mapping, Optional

from pddl.core import Domain, Problem
from pddl.custom_types import name
Expand Down Expand Up @@ -40,7 +40,7 @@ def _sort_and_print_collection(

def _print_types_with_parents(
prefix: str,
types_dict: Dict[name, Optional[name]],
types_dict: Mapping[name, Optional[name]],
postfix: str,
to_string: Callable = str,
):
Expand Down
12 changes: 11 additions & 1 deletion pddl/logic/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@

from pddl.custom_types import name as name_type
from pddl.custom_types import namelike, parse_name
from pddl.definitions.base import TypesDef
from pddl.helpers.base import assert_
from pddl.helpers.cache_hash import cache_hash
from pddl.logic.base import Atomic, Formula
from pddl.logic.terms import Constant, Term
from pddl.parser.symbols import Symbols
from pddl.requirements import Requirements
from pddl.validation.terms import TermsValidator


Expand All @@ -29,7 +31,7 @@ class _BaseAtomic(Atomic):

def __init__(self, *terms: Term) -> None:
"""Initialize the atomic formula."""
TermsValidator.check_terms_consistency(terms)
self._check_terms_light(terms)
self._terms = tuple(terms)
self._is_ground: bool = all(isinstance(v, Constant) for v in self._terms)

Expand All @@ -43,6 +45,14 @@ def is_ground(self) -> bool:
"""Check whether the predicate is ground."""
return self._is_ground

def _check_terms_light(self, terms: Sequence[Term]) -> None:
"""
Check the terms of the predicate, but only type tags consistency.
This method only performs checks that do not require external information (e.g. types provided by the domain).
"""
TermsValidator({Requirements.TYPING}, TypesDef()).check_terms_consistency(terms)


@cache_hash
@functools.total_ordering
Expand Down
82 changes: 62 additions & 20 deletions pddl/validation/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,47 @@

"""Module for validator of terms."""
from functools import partial
from typing import Collection, Dict, Generator, Set
from typing import AbstractSet, Collection, Dict, Generator, Optional, Set, Type, Union

from pddl.custom_types import name as name_type
from pddl.definitions.base import TypesDef
from pddl.exceptions import PDDLValidationError
from pddl.helpers.base import check
from pddl.logic.terms import Term, _print_tag_set
from pddl.logic.terms import Constant, Term, Variable, _print_tag_set
from pddl.requirements import Requirements
from pddl.validation.base import BaseValidator


class TermsValidator(BaseValidator):
"""Class for validator of terms."""
"""
Class for validator of terms.
@classmethod
def check_terms_consistency(cls, terms: Collection[Term]):
Some machinery is required to make the code as much as reusable as possible.
"""

def __init__(
self,
requirements: AbstractSet[Requirements],
types: TypesDef,
must_be_instances_of: Optional[Union[Type[Constant], Type[Variable]]] = None,
):
"""Initialize the validator."""
super().__init__(requirements, types)

# if none, then we don't care if constant or variable
self._allowed_superclass = must_be_instances_of

def check_terms_consistency(self, terms: Collection[Term]):
"""
Check that there are no duplicates.
This is the non-iterative version of '_check_terms_consistency_iterator'.
"""
# consume the iterator
list(cls._check_terms_consistency_iterator(terms))
list(self._check_terms_consistency_iterator(terms))

@classmethod
def _check_terms_consistency_iterator(
cls, terms: Collection[Term]
self, terms: Collection[Term]
) -> Generator[Term, None, None]:
"""
Iterate over terms and check that terms with the same name must have the same type tags.
Expand All @@ -47,20 +63,46 @@ def _check_terms_consistency_iterator(
"""
seen: Dict[name_type, Set[name_type]] = {}
for term in terms:
if term.name not in seen:
seen[term.name] = set(term.type_tags)
else:
expected_type_tags = seen[term.name]
actual_type_tags = set(term.type_tags)
check(
expected_type_tags == actual_type_tags,
f"Term {term} occurred twice with different type tags: "
f"previous type tags {_print_tag_set(expected_type_tags)}, "
f"new type tags {_print_tag_set(actual_type_tags)}",
exception_cls=PDDLValidationError,
)
self._check_same_term_has_same_type_tags(term, seen)
self._check_term_type(term, term_type=self._allowed_superclass)
yield term

@classmethod
def _check_same_term_has_same_type_tags(
cls, term: Term, seen: Dict[name_type, Set[name_type]]
) -> None:
"""
Check if the term has already been seen and, if so, that it has the same type tags.
This is an auxiliary method to simplify the implementation of '_check_terms_consistency_iterator'.
"""
if term.name not in seen:
seen[term.name] = set(term.type_tags)
else:
expected_type_tags = seen[term.name]
actual_type_tags = set(term.type_tags)
check(
expected_type_tags == actual_type_tags,
f"Term {term} occurred twice with different type tags: "
f"previous type tags {_print_tag_set(expected_type_tags)}, "
f"new type tags {_print_tag_set(actual_type_tags)}",
exception_cls=PDDLValidationError,
)

@classmethod
def _check_term_type(
cls,
term: Term,
term_type: Optional[Union[Type[Constant], Type[Variable]]] = None,
):
"""Check that the term is of the specified type."""
if term_type is not None:
check(
isinstance(term, term_type),
f"expected '{term}' being of type {term_type.__name__}; got {term.__class__.__name__} instead",
exception_cls=PDDLValidationError,
)

def check_terms(self, terms: Collection[Term]) -> None:
"""Check the terms."""
terms_iter = self._check_terms_consistency_iterator(terms)
Expand Down
17 changes: 15 additions & 2 deletions tests/test_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,19 @@ def test_predicate_variable_type_not_available() -> None:
Domain("test", requirements={Requirements.TYPING}, predicates={p}, types=type_set) # type: ignore


def test_predicate_def_with_some_constants_raise_error() -> None:
"""Test that when a predicate in the predicate def section contains constants we raise error (only vars allowed)."""
x = Variable("x")
c = Constant("c")
p = Predicate("p", x, c)

with pytest.raises(
PDDLValidationError,
match=r"expected 'c' being of type Variable; got Constant instead",
):
Domain("test", predicates={p})


def test_action_parameter_type_not_available() -> None:
"""Test that when a type of a action parameter is not declared we raise error."""
x = Variable("a", type_tags={"t1", "t2"})
Expand All @@ -162,7 +175,7 @@ def test_action_parameter_type_not_available() -> None:

with pytest.raises(
PDDLValidationError,
match=rf"types \['t1', 't2'\] of term {re.escape(repr(x))} are not in available types {{'{my_type}'}}",
match=rf"types \['t1', 't2'\] of term {re.escape(repr(x))} are not in available types \['my_type'\]",
):
Domain("test", requirements={Requirements.TYPING}, actions={action}, types=type_set) # type: ignore

Expand All @@ -179,6 +192,6 @@ def test_derived_predicate_type_not_available() -> None:
with pytest.raises(
PDDLValidationError,
match=rf"type '(t1|t2)' of term {re.escape(repr(x))} in atomic expression {re.escape(repr(p))} is not in "
f"available types {{'{my_type}'}}",
r"available types \['my_type'\]",
):
Domain("test", requirements={Requirements.TYPING}, derived_predicates={dp}, types=type_set) # type: ignore

0 comments on commit 8743f7b

Please sign in to comment.