-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add code generator for otel proto (#39)
- Loading branch information
1 parent
2736fad
commit 0fd842e
Showing
17 changed files
with
1,981 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# This workflow will delete and regenerate the opentelemetry marshaling code using scripts/proto_codegen.sh.
# If generating the code produces any changes from what is currently checked in, the workflow will fail and prompt the user to regenerate the code.
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Check Codegen

on:
  push:
    branches: [ "main" ]
    paths:
      - "scripts/**"
      - "src/snowflake/telemetry/_internal/opentelemetry/proto/**"
      - ".github/workflows/check-codegen.yml"
  pull_request:
    branches: [ "main" ]
    paths:
      - "scripts/**"
      - "src/snowflake/telemetry/_internal/opentelemetry/proto/**"
      - ".github/workflows/check-codegen.yml"

jobs:
  check-codegen:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python
        uses: actions/setup-python@v3
        with:
          python-version: "3.11"
      - name: Run codegen script
        run: |
          rm -rf src/snowflake/telemetry/_internal/opentelemetry/proto/
          ./scripts/proto_codegen.sh
      - name: Check for changes
        run: |
          # `git diff` only reports tracked files, so also look for generated
          # files that exist on disk but were never committed.
          untracked="$(git ls-files --others --exclude-standard -- src/snowflake/telemetry/_internal/opentelemetry/proto/)"
          if ! git diff --exit-code || [ -n "$untracked" ]; then
            if [ -n "$untracked" ]; then
              printf 'Untracked generated files:\n%s\n' "$untracked"
            fi
            echo "Code generation produced changes! Regenerate the code using ./scripts/proto_codegen.sh"
            exit 1
          fi
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,237 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import os | ||
import sys | ||
from dataclasses import dataclass, field | ||
from typing import List, Optional | ||
from enum import IntEnum | ||
|
||
from google.protobuf.compiler import plugin_pb2 as plugin | ||
from google.protobuf.descriptor_pb2 import ( | ||
FileDescriptorProto, | ||
FieldDescriptorProto, | ||
EnumDescriptorProto, | ||
EnumValueDescriptorProto, | ||
MethodDescriptorProto, | ||
ServiceDescriptorProto, | ||
DescriptorProto, | ||
) | ||
from jinja2 import Environment, FileSystemLoader | ||
import black | ||
import isort.api | ||
|
||
class WireType(IntEnum):
    """Numeric wire-type codes that occupy the low three bits of a field tag."""

    # Variable-length integers (used below for bool, enum and the int families).
    VARINT = 0
    # Fixed-width 64-bit values (fixed64, sfixed64, double).
    I64 = 1
    # Length-delimited payloads (string, bytes, message, packed repeated).
    LEN = 2
    # Fixed-width 32-bit values (fixed32, sfixed32, float).
    I32 = 5
|
||
@dataclass
class ProtoTypeDescriptor:
    """Static facts about one protobuf scalar/message type used during codegen."""

    # Short proto type name, e.g. "int32" or "message".
    name: str
    # Wire type used when a single (non-repeated) value of this type is encoded.
    wire_type: WireType
    # Python type name emitted into generated annotations, e.g. "int".
    python_type: str
|
||
# Lookup table from FieldDescriptorProto type constants to their codegen
# descriptor. Each row is (proto type constant, proto name, wire type, python type).
proto_type_to_descriptor = {
    proto_type: ProtoTypeDescriptor(proto_name, wire_type, python_type)
    for proto_type, proto_name, wire_type, python_type in (
        (FieldDescriptorProto.TYPE_BOOL, "bool", WireType.VARINT, "bool"),
        (FieldDescriptorProto.TYPE_ENUM, "enum", WireType.VARINT, "int"),
        (FieldDescriptorProto.TYPE_INT32, "int32", WireType.VARINT, "int"),
        (FieldDescriptorProto.TYPE_INT64, "int64", WireType.VARINT, "int"),
        (FieldDescriptorProto.TYPE_UINT32, "uint32", WireType.VARINT, "int"),
        (FieldDescriptorProto.TYPE_UINT64, "uint64", WireType.VARINT, "int"),
        (FieldDescriptorProto.TYPE_SINT32, "sint32", WireType.VARINT, "int"),
        (FieldDescriptorProto.TYPE_SINT64, "sint64", WireType.VARINT, "int"),
        (FieldDescriptorProto.TYPE_FIXED32, "fixed32", WireType.I32, "int"),
        (FieldDescriptorProto.TYPE_FIXED64, "fixed64", WireType.I64, "int"),
        (FieldDescriptorProto.TYPE_SFIXED32, "sfixed32", WireType.I32, "int"),
        (FieldDescriptorProto.TYPE_SFIXED64, "sfixed64", WireType.I64, "int"),
        (FieldDescriptorProto.TYPE_FLOAT, "float", WireType.I32, "float"),
        (FieldDescriptorProto.TYPE_DOUBLE, "double", WireType.I64, "float"),
        (FieldDescriptorProto.TYPE_STRING, "string", WireType.LEN, "str"),
        (FieldDescriptorProto.TYPE_BYTES, "bytes", WireType.LEN, "bytes"),
        # NOTE(review): messages map to "bytes" -- presumably nested messages
        # are handed over already serialized; confirm against the template.
        (FieldDescriptorProto.TYPE_MESSAGE, "message", WireType.LEN, "bytes"),
    )
}
|
||
@dataclass
class EnumValueTemplate:
    """Template data for one enum value: its name and its number."""

    name: str
    number: int

    @staticmethod
    def from_descriptor(descriptor: EnumValueDescriptorProto) -> "EnumValueTemplate":
        """Build the template from the ``name`` and ``number`` of ``descriptor``."""
        value_name, value_number = descriptor.name, descriptor.number
        return EnumValueTemplate(name=value_name, number=value_number)
|
||
@dataclass
class EnumTemplate:
    """Template data for an enum: its (possibly parent-qualified) name and values."""

    name: str
    values: List["EnumValueTemplate"] = field(default_factory=list)

    @staticmethod
    def from_descriptor(descriptor: EnumDescriptorProto, parent: str = "") -> "EnumTemplate":
        """Build the template for ``descriptor``.

        When ``parent`` is non-empty the resulting name is ``<parent>_<name>``;
        otherwise the descriptor's own name is used unchanged.
        """
        if parent:
            qualified_name = "_".join((parent, descriptor.name))
        else:
            qualified_name = descriptor.name

        value_templates = []
        for value_descriptor in descriptor.value:
            value_templates.append(EnumValueTemplate.from_descriptor(value_descriptor))

        return EnumTemplate(name=qualified_name, values=value_templates)
|
||
def tag_to_repr_varint(tag: int) -> str:
    """Return the ``repr`` of ``tag`` encoded as a base-128 varint.

    The value is emitted least-significant 7-bit group first; every byte
    except the last has its high bit (0x80) set as a continuation marker.
    The result is the Python source text of a ``bytes`` literal (for example
    ``"b'\\n'"`` for tag 10), not the bytes themselves.

    Args:
        tag: Non-negative integer to encode.

    Returns:
        ``repr()`` of the encoded bytes.

    Raises:
        ValueError: If ``tag`` is negative. A negative value cannot be
            expressed by this loop and would otherwise fail with an
            unrelated "byte must be in range(0, 256)" message.
    """
    if tag < 0:
        raise ValueError(f"tag must be non-negative, got {tag}")

    out = bytearray()
    while tag >= 128:
        # Low 7 bits plus the continuation bit; more groups follow.
        out.append((tag & 0x7F) | 0x80)
        tag >>= 7
    # Final group: high bit clear marks the end of the varint.
    out.append(tag)
    return repr(bytes(out))
|
||
@dataclass
class FieldTemplate:
    """Template data describing one message field for the code generator."""

    name: str
    number: int
    # repr() of the varint-encoded tag bytes, precomputed at generation time.
    tag: str
    # Python type name for annotations; wrapped as "List[...]" when repeated.
    python_type: str
    # Proto type name; prefixed with "repeated_" when repeated.
    proto_type: str
    repeated: bool
    # Name of the group / oneof this field belongs to, or None when it has none.
    group: Optional[str]
    encode_presence: bool

    @staticmethod
    def from_descriptor(descriptor: FieldDescriptorProto, group: Optional[str] = None) -> "FieldTemplate":
        """Build the template for ``descriptor``.

        Args:
            descriptor: Field descriptor to convert.
            group: Name of the group / oneof containing the field, if any.

        Returns:
            The populated ``FieldTemplate``.

        Raises:
            KeyError: If the field's type has no entry in
                ``proto_type_to_descriptor``.
        """
        repeated = descriptor.label == FieldDescriptorProto.LABEL_REPEATED
        try:
            type_descriptor = proto_type_to_descriptor[descriptor.type]
        except KeyError:
            # Same exception type as the bare lookup, but say which field failed.
            raise KeyError(
                f"unsupported proto type {descriptor.type} for field {descriptor.name!r}"
            ) from None

        python_type = type_descriptor.python_type
        proto_type = type_descriptor.name
        wire_type = type_descriptor.wire_type

        if repeated:
            python_type = f"List[{python_type}]"
            proto_type = f"repeated_{proto_type}"
            # Special case: repeated primitive fields are packed, so they use
            # the length-delimited wire type. Types that are already LEN are
            # unaffected by this assignment.
            wire_type = WireType.LEN

        # Convert the tag to a varint representation up front;
        # saves us from having to calculate the tag at runtime.
        tag = tag_to_repr_varint((descriptor.number << 3) | wire_type.value)

        # For group / oneof fields, we need to encode the presence of the field.
        # For message fields, we need to encode the presence of the field if it is not None.
        # Repeated messages have proto_type "repeated_message" and do not match here.
        encode_presence = group is not None or proto_type == "message"

        return FieldTemplate(
            name=descriptor.name,
            tag=tag,
            number=descriptor.number,
            python_type=python_type,
            proto_type=proto_type,
            repeated=repeated,
            group=group,
            encode_presence=encode_presence,
        )
|
||
@dataclass
class MessageTemplate:
    """Template data for a message: its fields, nested enums and nested messages."""

    name: str
    fields: List[FieldTemplate] = field(default_factory=list)
    enums: List["EnumTemplate"] = field(default_factory=list)
    messages: List["MessageTemplate"] = field(default_factory=list)

    @staticmethod
    def from_descriptor(descriptor: DescriptorProto, parent: str = "") -> "MessageTemplate":
        """Build the template for ``descriptor``, recursing into nested types.

        Fields are ordered by field number. Nested enums and messages are
        named ``<this message's name>_<nested name>``; this message itself is
        prefixed with ``parent`` in the same way when ``parent`` is non-empty.
        """
        field_templates = []
        for field_descriptor in descriptor.field:
            # A field that sets oneof_index belongs to that oneof declaration.
            if field_descriptor.HasField("oneof_index"):
                oneof_name = descriptor.oneof_decl[field_descriptor.oneof_index].name
            else:
                oneof_name = None
            field_templates.append(FieldTemplate.from_descriptor(field_descriptor, oneof_name))
        field_templates.sort(key=lambda template: template.number)

        if parent:
            qualified_name = "_".join((parent, descriptor.name))
        else:
            qualified_name = descriptor.name

        nested_enums = [
            EnumTemplate.from_descriptor(enum_descriptor, qualified_name)
            for enum_descriptor in descriptor.enum_type
        ]
        nested_messages = [
            MessageTemplate.from_descriptor(nested_descriptor, qualified_name)
            for nested_descriptor in descriptor.nested_type
        ]

        return MessageTemplate(
            name=qualified_name,
            fields=field_templates,
            enums=nested_enums,
            messages=nested_messages,
        )
|
||
@dataclass
class MethodTemplate:
    """Template data for a service method and its input/output messages."""

    name: str
    input_message: MessageTemplate
    output_message: MessageTemplate

    @staticmethod
    def from_descriptor(descriptor: MethodDescriptorProto) -> "MethodTemplate":
        """Build the template for ``descriptor``.

        The input and output messages are name-only ``MessageTemplate``
        placeholders carrying ``input_type`` / ``output_type``.
        """
        request_message = MessageTemplate(name=descriptor.input_type)
        response_message = MessageTemplate(name=descriptor.output_type)
        return MethodTemplate(
            name=descriptor.name,
            input_message=request_message,
            output_message=response_message,
        )
|
||
@dataclass
class ServiceTemplate:
    """Template data for a service: its name and its methods."""

    name: str
    methods: List["MethodTemplate"] = field(default_factory=list)

    @staticmethod
    def from_descriptor(descriptor: ServiceDescriptorProto) -> "ServiceTemplate":
        """Build the template for ``descriptor``, converting each of its methods."""
        method_templates = []
        for method_descriptor in descriptor.method:
            method_templates.append(MethodTemplate.from_descriptor(method_descriptor))
        return ServiceTemplate(name=descriptor.name, methods=method_templates)
|
||
@dataclass
class FileTemplate:
    """Template data for one .proto file: top-level messages, enums and services."""

    messages: List["MessageTemplate"] = field(default_factory=list)
    enums: List["EnumTemplate"] = field(default_factory=list)
    services: List["ServiceTemplate"] = field(default_factory=list)
    name: str = ""

    @staticmethod
    def from_descriptor(descriptor: FileDescriptorProto) -> "FileTemplate":
        """Build the template for ``descriptor`` from its top-level declarations."""
        top_level_messages = [
            MessageTemplate.from_descriptor(message_descriptor)
            for message_descriptor in descriptor.message_type
        ]
        top_level_enums = [
            EnumTemplate.from_descriptor(enum_descriptor)
            for enum_descriptor in descriptor.enum_type
        ]
        service_templates = [
            ServiceTemplate.from_descriptor(service_descriptor)
            for service_descriptor in descriptor.service
        ]
        return FileTemplate(
            messages=top_level_messages,
            enums=top_level_enums,
            services=service_templates,
            name=descriptor.name,
        )
|
||
def main():
    """Run as a protoc plugin: read a request on stdin, write a response on stdout.

    Parses a serialized ``CodeGeneratorRequest`` from stdin, renders
    ``templates/template.py.jinja2`` (located next to this script) for each
    file protoc asked to generate, formats the result with isort and black,
    and writes the serialized ``CodeGeneratorResponse`` to stdout.
    """
    request = plugin.CodeGeneratorRequest()
    request.ParseFromString(sys.stdin.buffer.read())

    response = plugin.CodeGeneratorResponse()
    # needed since metrics.proto uses proto3 optional fields
    response.supported_features = plugin.CodeGeneratorResponse.FEATURE_PROTO3_OPTIONAL

    templates_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "templates")
    template_env = Environment(loader=FileSystemLoader(templates_dir))
    jinja_body_template = template_env.get_template("template.py.jinja2")

    # request.proto_file also carries every imported dependency; only the
    # files named in file_to_generate should produce output.
    files_to_generate = set(request.file_to_generate)
    proto_suffix = ".proto"

    for proto_file in request.proto_file:
        if proto_file.name not in files_to_generate:
            continue

        # Swap only the trailing extension, never a ".proto" inside the path.
        base_name = proto_file.name
        if base_name.endswith(proto_suffix):
            base_name = base_name[: -len(proto_suffix)]
        file_name = base_name + ".py"

        file_template = FileTemplate.from_descriptor(proto_file)

        code = jinja_body_template.render(file_template=file_template)
        code = isort.api.sort_code_string(
            code=code,
            show_diff=False,
            profile="black",
            combine_as_imports=True,
            lines_after_imports=2,
            quiet=True,
            force_grid_wrap=2,
        )
        code = black.format_str(
            src_contents=code,
            mode=black.Mode(),
        )

        response_file = response.file.add()
        response_file.name = file_name
        response_file.content = code

    sys.stdout.buffer.write(response.SerializeToString())


if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
#!/bin/bash
#
# Regenerate python code from OTLP protos in
# https://github.com/open-telemetry/opentelemetry-proto
#
# To use, update PROTO_REPO_BRANCH_OR_COMMIT variable below to a commit hash or
# tag in opentelemetry-proto repo that you want to build off of. Then, just run
# this script to update the proto files. Commit the changes as well as any
# fixes needed in the OTLP exporter.
#
# Optional envars:
#   PROTO_REPO_DIR - the path to an existing checkout of the opentelemetry-proto repo

# Pinned commit/branch/tag for the current version used in opentelemetry-proto python package.
PROTO_REPO_BRANCH_OR_COMMIT="v1.2.0"

set -e

PROTO_REPO_DIR=${PROTO_REPO_DIR:-"/tmp/opentelemetry-proto"}
# root of the repository this script lives in
repo_root="$(git rev-parse --show-toplevel)"
venv_dir="/tmp/proto_codegen_venv"

# run on exit even if crash
cleanup() {
    echo "Deleting $venv_dir"
    rm -rf "$venv_dir"
}
trap cleanup EXIT

echo "Creating temporary virtualenv at $venv_dir using $(python3 --version)"
python3 -m venv "$venv_dir"
source "$venv_dir/bin/activate"
python -m pip install protobuf Jinja2 grpcio-tools black isort
echo 'python -m grpc_tools.protoc --version'
python -m grpc_tools.protoc --version

# Clone the proto repo if it doesn't exist
if [ ! -d "$PROTO_REPO_DIR" ]; then
    git clone https://github.com/open-telemetry/opentelemetry-proto.git "$PROTO_REPO_DIR"
fi

# Pull in changes and switch to requested branch
(
    cd "$PROTO_REPO_DIR"
    git fetch --all
    git checkout "$PROTO_REPO_BRANCH_OR_COMMIT"
    # pull if PROTO_REPO_BRANCH_OR_COMMIT is not a detached head
    git symbolic-ref -q HEAD && git pull --ff-only || true
)

cd "$repo_root/src/snowflake/telemetry/_internal"

# clean up old generated code
# (match exactly *.py files; the previous regex ".*\.py?" also matched *.p)
mkdir -p opentelemetry/proto
find opentelemetry/proto/ -type f -name "*.py" -exec rm {} +

# generate proto code for all protos
all_protos=$(find "$PROTO_REPO_DIR/" -iname "*.proto")
# $all_protos is deliberately left unquoted so each path becomes its own
# argument; this relies on the proto paths containing no whitespace.
python -m grpc_tools.protoc \
    -I "$PROTO_REPO_DIR" \
    --plugin=protoc-gen-custom-plugin="$repo_root/scripts/plugin.py" \
    --custom-plugin_out=. \
    $all_protos
Oops, something went wrong.