-
Notifications
You must be signed in to change notification settings - Fork 112
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
602010d
commit 1a292e3
Showing
6 changed files
with
337 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
syntax = "proto3"; | ||
|
||
option cc_generic_services = false; | ||
option java_outer_classname = "SnowparkTcmProto"; | ||
option java_package = "com.snowflake.snowpark.proto.tcm"; | ||
|
||
import "google/protobuf/descriptor.proto"; | ||
|
||
package tcm; | ||
|
||
// Discriminates how the opaque `payload` bytes of a TcmRequest or
// TcmResponse are to be interpreted.
enum TcmPayloadType {
  // The payload is a Snowpark Python AST message. Because this is the zero
  // value, a message whose `type` was never set also reads as this value.
  SNOWPARK_PYTHON_AST = 0;
  // ... future extensions, e.g., Spark API, pandas API, etc.
}
|
||
// Envelope for one request sent to the TCM.
message TcmRequest {
  // Opaque request body, e.g., a serialized Dataframe AST request. Its
  // encoding is identified by `type`.
  bytes payload = 1;

  // Kind of content stored in `payload`.
  TcmPayloadType type = 2;

  // Identifier of this request; the matching TcmResponse carries the same
  // value in its own `request_id`.
  int64 request_id = 3;

  // ... future extensions, e.g., telemetry, metrics config.
}
|
||
// Envelope for the outcome of one TcmRequest.
message TcmResponse {
  // Categories of failure the TCM can report.
  enum ErrorCode {
    // The TCM failed to initialize. This is the zero value, so it is also
    // what an unset `Error.code` reads as.
    TCM_INIT_ERROR = 0;
    // The TCM failed while executing the request.
    TCM_EXECUTION_ERROR = 1;
    // ... other error codes
  }

  // Describes a failed request.
  message Error {
    // Which category of failure occurred.
    ErrorCode code = 1;
    // Accompanying error text.
    string message = 2;
  }

  // Single-valued enum used to represent a null result.
  enum NullValue {
    NULL_VALUE = 0;
  }

  // Describes a successful request. At most one member of `value` is set.
  message OkResult {
    oneof value {
      // UUID of the query, in string form.
      string query_uuid = 1;
      // A string result.
      string val_str = 2;
      // A 64-bit integer result.
      int64 val_int = 3;
      // A double-precision floating point result.
      double val_double = 4;
      // A boolean result.
      bool val_bool = 5;
      // A null result.
      NullValue val_null = 6;
      // ... future extensions, e.g., map, struct, etc.
    }
  }

  // Opaque response body, e.g., a serialized Dataframe AST response. Its
  // encoding is identified by `type`.
  bytes payload = 1;

  // Kind of content stored in `payload`.
  TcmPayloadType type = 2;

  // Identifier of the request this response belongs to.
  int64 request_id = 3;

  // Outcome of the request: either an error, or an OkResult that carries a
  // query UUID or a non-query value.
  oneof result {
    // Set when the request failed.
    Error error = 4;
    // Set when the request succeeded.
    OkResult okResult = 5;
    // ... future extensions, e.g., telemetry, metrics, log, etc.
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import json | ||
import logging | ||
import snowflake.snowpark | ||
import snowflake.snowpark._internal.proto.generated.ast_pb2 as ast | ||
import snowflake.snowpark._internal.proto.generated.SnowparkTcm_pb2 as tcm_proto | ||
from snowflake.snowpark._internal.tcm.sp_ast_decoder import SnowparkAstDecoder | ||
import base64 | ||
from google.protobuf import message | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def str2proto(b64_input: str, proto_output: message.Message) -> None:
    """Fill ``proto_output`` in place from a base64-encoded serialized message.

    Args:
        b64_input: base64 text wrapping the serialized protobuf bytes.
        proto_output: message instance that parses the decoded bytes.
    """
    proto_output.ParseFromString(base64.b64decode(b64_input))
|
||
|
||
def proto2str(proto_input: message.Message) -> str:
    """Serialize ``proto_input`` and return the bytes as base64 text.

    Args:
        proto_input: the protobuf message to serialize.

    Returns:
        The base64 encoding of the serialized message, as a ``str``.
    """
    serialized = proto_input.SerializeToString()
    encoded = base64.b64encode(serialized)
    return encoded.decode("utf-8")
|
||
|
||
class TcmSession:
    """Executes base64-encoded ``TcmRequest`` messages against a Snowpark session."""

    def __init__(self, session: snowflake.snowpark.Session):
        """
        Binds the TCM to the Snowpark session that requests are executed on.

        Args:
            session: the Snowpark session handed to the payload decoder. It is
                required; no fallback session is created here.
        """
        self._session = session
        # Created on first use by get_decoder() and reused afterwards.
        self._decoder = None

    def get_decoder(self, type):
        """
        Returns the decoder for the given payload type, creating it on first use.

        Args:
            type: a ``TcmPayloadType`` value taken from the request.

        Raises:
            NotImplementedError: if ``type`` is not ``SNOWPARK_PYTHON_AST``.
        """
        if type != tcm_proto.SNOWPARK_PYTHON_AST:
            raise NotImplementedError(f"Unsupported TCM payload type: {type}")
        if self._decoder is None:
            self._decoder = SnowparkAstDecoder(self._session)
        return self._decoder

    def construct_tcm_response_from_ast(self, ast_res_proto, type, rid):
        """
        Wraps an AST response in a ``TcmResponse``.

        Args:
            ast_res_proto: the AST response produced by the decoder.
            type: payload type to record on the response.
            rid: request id to echo back on the response.

        Returns:
            A ``TcmResponse`` whose payload is the serialized AST response. The
            ``okResult`` carries the query UUID of the first result; when the
            AST response has no results (e.g. the request evaluated nothing)
            the ``okResult`` is left without a value instead of failing.
        """
        if len(ast_res_proto.body) > 0:
            uuid = ast_res_proto.body[0].eval_ok.data.sf_query_result.uuid
            ok_result = tcm_proto.TcmResponse.OkResult(query_uuid=uuid)
        else:
            # Nothing was evaluated, so there is no query UUID to report.
            ok_result = tcm_proto.TcmResponse.OkResult()

        return tcm_proto.TcmResponse(
            payload=ast_res_proto.SerializeToString(),
            type=type,
            request_id=rid,
            okResult=ok_result,
        )

    def request(self, tcm_req_base64: str) -> str:
        """
        Decodes a request, runs its payload through the decoder and encodes the response.

        Args:
            tcm_req_base64: base64 text of a serialized ``TcmRequest``.

        Returns:
            base64 text of the serialized ``TcmResponse``.

        Raises:
            NotImplementedError: if the request's payload type is not supported.
        """
        tcm_req_proto = tcm_proto.TcmRequest()
        str2proto(tcm_req_base64, tcm_req_proto)
        rid = tcm_req_proto.request_id
        # Use the module logger (not the root logger) so log configuration applies.
        logger.debug("request id: %s", rid)

        # get_decoder() rejects unsupported payload types, so the check also
        # holds when Python runs with assertions disabled.
        decoder = self.get_decoder(tcm_req_proto.type)

        ast_req_proto = ast.Request()
        ast_req_proto.ParseFromString(tcm_req_proto.payload)

        ast_res_proto = decoder.request(ast_req_proto)

        tcm_res_proto = self.construct_tcm_response_from_ast(ast_res_proto, tcm_req_proto.type, rid)
        return proto2str(tcm_res_proto)
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import Any | ||
|
||
import snowflake | ||
import logging | ||
import snowflake.snowpark._internal.proto.generated.ast_pb2 as ast | ||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def column_expr(e):
    """Return the SQL text of a column expression.

    Only the ``sp_column_sql_expr`` variant is handled; any other variant is
    logged as a warning and yields ``None``.
    """
    if e.WhichOneof('variant') != 'sp_column_sql_expr':
        logger.warning('Unexpected column expr %s', str(e))
        return None
    return e.sp_column_sql_expr.sql
|
||
|
||
class SnowparkAstDecoder:
    """Replays AST requests against a Snowpark session, keeping variable bindings between requests."""

    def __init__(self, session: snowflake.snowpark.Session):
        """
        Args:
            session: the Snowpark session expressions are executed on. AST
                collection is switched on for it.
        """
        self._session = session
        self._session.ast_enabled = True
        # Maps a variable id (``var_id.bitfield1``) to the value assigned to it.
        self.bindings = {}

    def request(self, req: ast.Request) -> ast.Response:
        """
        Processes every statement of ``req`` in order.

        ``assign`` statements update the bindings; ``eval`` statements append an
        ``eval_ok`` result to the response. Other statements are logged and skipped.

        Returns:
            The response holding one result per ``eval`` statement.
        """
        resp = ast.Response()

        for stmt in req.body:
            variant = stmt.WhichOneof('variant')
            if variant == 'assign':
                self.assign(stmt.assign)
            elif variant == 'eval':
                ans = self.eval(stmt.eval)

                # Fill the evaluation result into the response so the client
                # side can reconstruct values with results.
                ok_result = ast.EvalOk(uid=stmt.eval.uid, var_id=stmt.eval.var_id, data=ans)
                res = ast.Result()
                res.eval_ok.CopyFrom(ok_result)
                resp.body.add().CopyFrom(res)
            else:
                logger.warning('Unexpected statement %s', stmt)
        logger.info('Session bindings %s', self.bindings)
        return resp

    def get_binding(self, var_id):
        """
        Looks up the value bound to ``var_id``.

        Raises:
            KeyError: if nothing has been assigned to that variable id.
        """
        key = var_id.bitfield1
        try:
            return self.bindings[key]
        except KeyError:
            raise KeyError(f"No binding for variable id {key}") from None

    def assign(self, assign) -> None:
        """Evaluates the statement's expression and binds the result to its variable id."""
        val = self.expr(assign.expr)
        self.bindings[assign.var_id.bitfield1] = val

    def eval(self, eval) -> Any:
        """Returns the value currently bound to the statement's variable id."""
        res = self.get_binding(eval.var_id)
        logger.info('Return atom %s := %s', eval.var_id.bitfield1, res)
        return res

    def _submit_query(self, df):
        """Starts the dataframe's query via ``collect_nowait()`` and wraps its query id in an EvalResult."""
        job = df.collect_nowait()
        return ast.EvalResult(sf_query_result=ast.SfQueryResult(uuid=job.query_id))

    def expr(self, e):
        """
        Evaluates one AST expression.

        Returns:
            The bound value, a dataframe, or an ``ast.EvalResult`` depending on
            the variant; ``None`` (after a warning) for unsupported variants.
        """
        variant = e.WhichOneof('variant')
        if variant == 'sp_dataframe_ref':
            return self.get_binding(e.sp_dataframe_ref.id)
        elif variant == 'sp_sql':
            return self._session.sql(e.sp_sql.query)
        elif variant == 'sp_table':
            return self._session.table(e.sp_table.table)
        elif variant == 'sp_dataframe_filter':
            e = e.sp_dataframe_filter
            df = self.expr(e.df)
            # column_expr() yields None for column variants it does not support.
            condition = column_expr(e.condition)
            return df.filter(condition)
        elif variant == 'sp_dataframe_show':
            # Return the same EvalResult shape as collect: request() passes the
            # bound value as EvalOk.data, which a bare query-id string is not.
            return self._submit_query(self.get_binding(e.sp_dataframe_show.id))
        elif variant == 'sp_dataframe_collect':
            return self._submit_query(self.get_binding(e.sp_dataframe_collect.id))
        else:
            logger.warning('Unexpected expr %s', e)
            return None
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Test in this file the TCM. | ||
# If run on workspace/devvm with parameters configured, this should internally call a running GS/XP instance. | ||
from snowflake.snowpark._internal.ast_utils import base64_lines_to_textproto | ||
from snowflake.snowpark._internal.tcm.session import TcmSession, str2proto, proto2str | ||
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto | ||
import snowflake.snowpark._internal.proto.generated.SnowparkTcm_pb2 as tcm_proto | ||
from google.protobuf.json_format import MessageToDict | ||
import base64 | ||
import json | ||
|
||
from snowflake.snowpark.functions import col | ||
|
||
|
||
def test_ast_gen(session):
    """Record the AST of a trivial query, send it through the TCM and dump every stage."""
    tcm = TcmSession(session)
    session.ast_enabled = True

    # Start from a clean slate: fresh entity ids and no pending AST.
    session._ast_batch.reset_id_gen()
    session._ast_batch.flush()

    # Run the query under test while capturing the generated AST batches.
    with session.ast_listener() as listener:
        rows = session.sql("select ln(3)").collect()
        print(rows)

    encoded_ast = "\n".join(listener.base64_batches)
    print(encoded_ast)
    print(f"len:{len(encoded_ast)}")

    # This is what GS would receive; it builds the TCM request from it.
    ast_binary = base64.b64decode(encoded_ast)
    print(f"bin len: {len(ast_binary)}")

    print(base64_lines_to_textproto(encoded_ast.strip()))

    tcm_request = tcm_proto.TcmRequest(
        payload=ast_binary,
        type=tcm_proto.SNOWPARK_PYTHON_AST,
        request_id=1,
    )
    tcm_req_str = base64.b64encode(tcm_request.SerializeToString()).decode("utf-8")

    print("tcm request:")
    print(json.dumps(MessageToDict(tcm_request), indent=2))
    print("tcm request str:")
    print(tcm_req_str)
    print("tcm request str end:")

    tcm_res_str = tcm.request(tcm_req_str)
    print(tcm_res_str)

    # Decode the TCM envelope and dump it.
    tcm_res = tcm_proto.TcmResponse()
    str2proto(tcm_res_str, tcm_res)
    print(json.dumps(MessageToDict(tcm_res), indent=2))

    # Decode the AST response carried in the envelope's payload and dump it.
    ast_res = proto.Response()
    ast_res.ParseFromString(tcm_res.payload)
    print(json.dumps(MessageToDict(ast_res), indent=2))
|
||
|
||
def test_tcm_eval(session):
    """Send a pre-recorded AST request through the TCM and dump the decoded response.

    The TCM is driven the way ``TcmSession.request`` expects: a base64-encoded
    ``TcmRequest`` goes in and a base64-encoded ``TcmResponse`` comes back.
    """
    tcm = TcmSession(session)

    # Pre-recorded, base64-encoded AST request.
    req_base64 = "CgIIKhIICgYKBAgDEAgacwpxCAESAggBGmOKCmAKUgpOL1VzZXJzL2xzcGllZ2VsYmVyZy9wcm9qZWN0cy9zbm93cGFyay1weXRob24vdGVzdHMvdGhpbi1jbGllbnQvc3RlZWwtdGhyZWFkLnB5EDciCnRlc3RfdGFibGUiBAoCZGYaEwoRCAISAggCGgfqBwQSAggBIgAaCBIGCAMSAggC"
    print(base64_lines_to_textproto(req_base64.strip()))
    req_binary = base64.b64decode(req_base64)

    # Make sure the recorded bytes still parse as an AST request before sending them.
    req_proto = proto.Request()
    req_proto.ParseFromString(req_binary)

    # TcmRequest lives in the TCM proto module and carries the AST as opaque bytes.
    tcm_request = tcm_proto.TcmRequest(
        payload=req_binary,
        type=tcm_proto.SNOWPARK_PYTHON_AST,
        request_id=1,
    )

    tcm_res_str = tcm.request(proto2str(tcm_request))

    tcm_res = tcm_proto.TcmResponse()
    str2proto(tcm_res_str, tcm_res)
    # The response echoes the id of the request it answers.
    assert tcm_res.request_id == 1

    ans = MessageToDict(tcm_res)
    print(json.dumps(ans, indent=2))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters