From 0378493e3047d7fce424b6a7371f53cbc8735e86 Mon Sep 17 00:00:00 2001 From: Jeevan Opel Date: Tue, 12 Nov 2024 09:33:48 -0800 Subject: [PATCH] Inline init --- scripts/plugin.py | 19 +++++++++++++++++-- scripts/templates/template.py.jinja2 | 2 +- .../telemetry/_internal/serialize/__init__.py | 12 +++++++----- 3 files changed, 25 insertions(+), 8 deletions(-) diff --git a/scripts/plugin.py b/scripts/plugin.py index acb27fa..dcef463 100755 --- a/scripts/plugin.py +++ b/scripts/plugin.py @@ -37,7 +37,6 @@ def inline_size_function(proto_type: str, field_name: str, field_tag: str) -> st function_definition = dedent(function_definition) # Replace the field name function_definition = function_definition.replace("FIELD_ATTR", f"self.{field_name}") - function_definition = function_definition.replace("CACHED_FIELD", field_name) # Replace the TAG function_definition = function_definition.replace("TAG", field_tag) # Inline the return statement @@ -53,11 +52,19 @@ def inline_serialize_function(proto_type: str, field_name: str, field_tag: str) function_definition = dedent(function_definition) # Replace the field name function_definition = function_definition.replace("FIELD_ATTR", f"self.{field_name}") - function_definition = function_definition.replace("CACHED_FIELD", field_name) # Replace the TAG function_definition = function_definition.replace("TAG", field_tag) return function_definition +# Inline the init function for a proto message +def inline_init() -> str: + function_definition = inspect.getsource(globals()["MessageMarshaler"].__dict__["__init__"]) + # Remove the function header and unindent the function body + function_definition = function_definition.splitlines()[1:] + function_definition = "\n".join(function_definition) + function_definition = dedent(function_definition) + return function_definition + # Add a presence check to a function definition # https://protobuf.dev/programming-guides/proto3/#default def add_presence_check(proto_type: str, encode_presence: bool, field_name: str, function_definition: str) -> str: @@ -238,6 +245,7 @@ def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = Non @dataclass class MessageTemplate: name: str + super_class_init: str fields: List[FieldTemplate] = field(default_factory=list) enums: List["EnumTemplate"] = field(default_factory=list) messages: List["MessageTemplate"] = field(default_factory=list) @@ -251,9 +259,16 @@ def get_group(field: FieldDescriptorProto) -> str: fields = [FieldTemplate.from_descriptor(field, get_group(field)) for field in descriptor.field] fields.sort(key=lambda field: field.number) + # Inline the superclass MessageMarshaler init function + if INLINE_OPTIMIZATION: + super_class_init = inline_init() + else: + super_class_init = "super().__init__()" + name = descriptor.name return MessageTemplate( name=name, + super_class_init=super_class_init, fields=fields, enums=[EnumTemplate.from_descriptor(enum) for enum in descriptor.enum_type], messages=[MessageTemplate.from_descriptor(message) for message in descriptor.nested_type], diff --git a/scripts/templates/template.py.jinja2 b/scripts/templates/template.py.jinja2 index 1f51495..c4242e8 100644 --- a/scripts/templates/template.py.jinja2 +++ b/scripts/templates/template.py.jinja2 @@ -47,7 +47,7 @@ class {{ message.name }}(MessageMarshaler): {%- for field in message.fields %} self.{{ field.attr_name }}: {{ field.python_type }} = {{ field.name }} {%- endfor %} - self._marshaler_cache = {} + {{ message.super_class_init | indent(8) }} def calculate_size(self) -> int: size = 0 diff --git a/src/snowflake/telemetry/_internal/serialize/__init__.py b/src/snowflake/telemetry/_internal/serialize/__init__.py index ae165fa..42f3c82 100644 --- a/src/snowflake/telemetry/_internal/serialize/__init__.py +++ b/src/snowflake/telemetry/_internal/serialize/__init__.py @@ -87,8 +87,11 @@ def write_varint_s32(out: bytearray, value: int) -> None: # Base class for all custom messages class MessageMarshaler: + _marshaler_cache: Dict[bytes, Any] + + # Init may be inlined by the code generator def __init__(self) -> None: - self._marshaler_cache: Dict[bytes, Any] = {} + self._marshaler_cache = {} def write_to(self, out: bytearray) -> None: ... @@ -102,7 +105,7 @@ def _get_size(self) -> int: return self._size def SerializeToString(self) -> bytes: - # size MUST be calculated before serializing since some preprocessing is done here + # size MUST be calculated before serializing since some preprocessing is done self._get_size() stream = bytearray() self.write_to(stream) @@ -111,11 +114,10 @@ def SerializeToString(self) -> bytes: def __bytes__(self) -> bytes: return self.SerializeToString() - # THE FOLLOWING FUNCTIONS CAN BE INLINED BY THE CODE GENERATOR - # The following strings are string replaced by the code generator if inlining: + # The following size and serialize functions may be inlined by the code generator + # The following strings are replaced by the code generator for inlining: # - TAG # - FIELD_ATTR - # - CACHED_FIELD def size_bool(self, TAG: bytes, _) -> int: return len(TAG) + 1