Skip to content

Commit

Permalink
Prohibiting null character in String field
Browse files Browse the repository at this point in the history
  • Loading branch information
Mattias Loverot committed Jan 27, 2021
1 parent 4ee6510 commit c07c814
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 0 deletions.
8 changes: 8 additions & 0 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,14 @@ class String(Field):
"invalid_utf8": "Not a valid utf-8 string.",
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Insert validation into self.validators so that multiple errors can be stored.
validator = validate.ProhibitNullCharactersValidator(
error=self.error_messages["invalid"]
)
self.validators.insert(0, validator)

def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
if value is None:
return None
Expand Down
28 changes: 28 additions & 0 deletions src/marshmallow/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,34 @@ def __call__(self, value) -> typing.Any:
return value


class ProhibitNullCharactersValidator(Validator):
"""Validate string not having Null Character
:param error: Error message to raise in case of a validation error. Can be
interpolated with `{input}`.
"""

default_message = "String contains null character"

NULL_REGEX = re.compile(
r"\0",
)

def __init__(self, *, error: typing.Optional[str] = None):
self.error = error or self.default_message # type: str

def _format_error(self, value) -> typing.Any:
return self.error.format(input=value)

def __call__(self, value) -> typing.Any:
message = self._format_error(value)

if value and self.NULL_REGEX.search(str(value)):
raise ValidationError(message)

return value


class Range(Validator):
"""Validator which succeeds if the value passed to it is within the specified
range. If ``min`` is not specified, or is specified as `None`,
Expand Down
7 changes: 7 additions & 0 deletions tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,13 @@ class MySchema(Schema):
result = MySchema().dump({"name": "Monty", "foo": 42})
assert result == {"_NaMe": "Monty"}

def test_string_field_null_char(self):
class MySchema(Schema):
name = fields.String()

with pytest.raises(ValidationError):
MySchema().load({"name": "a\0b"})


class TestParentAndName:
class MySchema(Schema):
Expand Down

0 comments on commit c07c814

Please sign in to comment.