Skip to content

Commit

Permalink
Inline init
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-jopel committed Nov 12, 2024
1 parent 1a8d826 commit 0378493
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
19 changes: 17 additions & 2 deletions scripts/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ def inline_size_function(proto_type: str, field_name: str, field_tag: str) -> st
function_definition = dedent(function_definition)
# Replace the field name
function_definition = function_definition.replace("FIELD_ATTR", f"self.{field_name}")
function_definition = function_definition.replace("CACHED_FIELD", field_name)
# Replace the TAG
function_definition = function_definition.replace("TAG", field_tag)
# Inline the return statement
Expand All @@ -53,11 +52,19 @@ def inline_serialize_function(proto_type: str, field_name: str, field_tag: str)
function_definition = dedent(function_definition)
# Replace the field name
function_definition = function_definition.replace("FIELD_ATTR", f"self.{field_name}")
function_definition = function_definition.replace("CACHED_FIELD", field_name)
# Replace the TAG
function_definition = function_definition.replace("TAG", field_tag)
return function_definition

# Inline the init function for a proto message
def inline_init() -> str:
function_definition = inspect.getsource(globals()["MessageMarshaler"].__dict__["__init__"])
# Remove the function header and unindent the function body
function_definition = function_definition.splitlines()[1:]
function_definition = "\n".join(function_definition)
function_definition = dedent(function_definition)
return function_definition

# Add a presence check to a function definition
# https://protobuf.dev/programming-guides/proto3/#default
def add_presence_check(proto_type: str, encode_presence: bool, field_name: str, function_definition: str) -> str:
Expand Down Expand Up @@ -238,6 +245,7 @@ def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = Non
@dataclass
class MessageTemplate:
name: str
super_class_init: str
fields: List[FieldTemplate] = field(default_factory=list)
enums: List["EnumTemplate"] = field(default_factory=list)
messages: List["MessageTemplate"] = field(default_factory=list)
Expand All @@ -251,9 +259,16 @@ def get_group(field: FieldDescriptorProto) -> str:
fields = [FieldTemplate.from_descriptor(field, get_group(field)) for field in descriptor.field]
fields.sort(key=lambda field: field.number)

# Inline the superclass MessageMarshaler init function
if INLINE_OPTIMIZATION:
super_class_init = inline_init()
else:
super_class_init = "super().__init__()"

name = descriptor.name
return MessageTemplate(
name=name,
super_class_init=super_class_init,
fields=fields,
enums=[EnumTemplate.from_descriptor(enum) for enum in descriptor.enum_type],
messages=[MessageTemplate.from_descriptor(message) for message in descriptor.nested_type],
Expand Down
2 changes: 1 addition & 1 deletion scripts/templates/template.py.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class {{ message.name }}(MessageMarshaler):
{%- for field in message.fields %}
self.{{ field.attr_name }}: {{ field.python_type }} = {{ field.name }}
{%- endfor %}
self._marshaler_cache = {}
{{ message.super_class_init | indent(8) }}

def calculate_size(self) -> int:
size = 0
Expand Down
12 changes: 7 additions & 5 deletions src/snowflake/telemetry/_internal/serialize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,11 @@ def write_varint_s32(out: bytearray, value: int) -> None:

# Base class for all custom messages
class MessageMarshaler:
_marshaler_cache: Dict[bytes, Any]

# Init may be inlined by the code generator
def __init__(self) -> None:
self._marshaler_cache: Dict[bytes, Any] = {}
self._marshaler_cache = {}

def write_to(self, out: bytearray) -> None:
...
Expand All @@ -102,7 +105,7 @@ def _get_size(self) -> int:
return self._size

def SerializeToString(self) -> bytes:
# size MUST be calculated before serializing since some preprocessing is done here
# size MUST be calculated before serializing since some preprocessing is done
self._get_size()
stream = bytearray()
self.write_to(stream)
Expand All @@ -111,11 +114,10 @@ def SerializeToString(self) -> bytes:
def __bytes__(self) -> bytes:
return self.SerializeToString()

# THE FOLLOWING FUNCTIONS CAN BE INLINED BY THE CODE GENERATOR
# The following strings are string replaced by the code generator if inlining:
# The following size and serialize functions may be inlined by the code generator
# The following strings are replaced by the code generator for inlining:
# - TAG
# - FIELD_ATTR
# - CACHED_FIELD

def size_bool(self, TAG: bytes, _) -> int:
return len(TAG) + 1
Expand Down

0 comments on commit 0378493

Please sign in to comment.