Skip to content

Commit

Permalink
tcm poc
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-azhan committed Nov 14, 2024
1 parent 602010d commit 1a292e3
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 9 deletions.
85 changes: 85 additions & 0 deletions src/snowflake/snowpark/_internal/proto/SnowparkTcm.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
syntax = "proto3";

option cc_generic_services = false;
option java_outer_classname = "SnowparkTcmProto";
option java_package = "com.snowflake.snowpark.proto.tcm";

import "google/protobuf/descriptor.proto";

package tcm;

enum TcmPayloadType {
SNOWPARK_PYTHON_AST = 0;
// ... future extensions, e.g., Spark API, pandas API, etc.
}

// The TCM request
message TcmRequest {
// The payload of the request, e.g., the Dataframe AST request.
bytes payload = 1;
// The type of the payload
TcmPayloadType type = 2;
// ID of the request
int64 request_id = 3;
// ... future extensions, e.g., telemetry, metrics config.
}

// Result for the TCM request
message TcmResponse {
// The types of TCM errors
enum ErrorCode {
// TCM Failed to initialize
TCM_INIT_ERROR = 0;
// TCM Failed to execute the request
TCM_EXECUTION_ERROR = 1;
// ... other error codes
}

// Indicates that the request has failed
message Error {
ErrorCode code = 1;
string message = 2;
}

// The null value
enum NullValue {
NULL_VALUE = 0;
}

// The non-query result
message OkResult {
oneof value {
// The query UUID in string format
string query_uuid = 1;
// The string value
string val_str = 2;
// The int64 value
int64 val_int = 3;
// The double value
double val_double = 4;
// The boolean value
bool val_bool = 5;
// The null value
NullValue val_null = 6;
// ... future extensions, e.g., map, struct, etc.
}
}

// The payload of the response, e.g., the Dataframe AST response.
bytes payload = 1;

// The type of the payload
TcmPayloadType type = 2;

// ID of the request with which this response is associated.
int64 request_id = 3;

// The result of the TCM request. It is either an error status, a query uuid, or a non-query result.
oneof result {
// The error status
Error error = 4;
// The okay result
OkResult okResult = 5;
// ... future extensions, e.g., telemetry, metrics, log, etc.
}
}
20 changes: 11 additions & 9 deletions src/snowflake/snowpark/_internal/proto/ast.proto
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,7 @@ message SpWindowRelativePosition {

// expr-window.ir:10
message SpWindowRelativePosition_Position {
<<<<<<< HEAD
Expr n = 1;
=======
int64 n = 1;
>>>>>>> 117c71b59 (cherry-pick #2549)
}

// pd-indexing.ir:2
Expand Down Expand Up @@ -1478,10 +1474,11 @@ message EvalResult {
PythonDateVal python_date_val = 14;
PythonTimeVal python_time_val = 15;
PythonTimestampVal python_timestamp_val = 16;
SpDatatypeVal sp_datatype_val = 17;
StringVal string_val = 18;
TimeVal time_val = 19;
TimestampVal timestamp_val = 20;
SfQueryResult sf_query_result = 17;
SpDatatypeVal sp_datatype_val = 18;
StringVal string_val = 19;
TimeVal time_val = 20;
TimestampVal timestamp_val = 21;
}
}

Expand All @@ -1504,6 +1501,11 @@ message SessionResetRequiredError {
VarId var_id = 2;
}

// result.ir:23
message SfQueryResult {
string uuid = 1;
}

message SpColumnExpr {
oneof variant {
SpColumnCaseWhen sp_column_case_when = 1;
Expand Down Expand Up @@ -3277,4 +3279,4 @@ message MapType {
// type.ir:56
message TyVar {
string id = 1;
}
}
68 changes: 68 additions & 0 deletions src/snowflake/snowpark/_internal/tcm/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import json
import logging
import snowflake.snowpark
import snowflake.snowpark._internal.proto.generated.ast_pb2 as ast
import snowflake.snowpark._internal.proto.generated.SnowparkTcm_pb2 as tcm_proto
from snowflake.snowpark._internal.tcm.sp_ast_decoder import SnowparkAstDecoder
import base64
from google.protobuf import message

logger = logging.getLogger(__name__)


def str2proto(b64_input: str, proto_output: message.Message) -> None:
decoded = base64.b64decode(b64_input)
proto_output.ParseFromString(decoded)


def proto2str(proto_input: message.Message) -> str:
return str(base64.b64encode(proto_input.SerializeToString()), "utf-8")


class TcmSession:
def __init__(self, session: snowflake.snowpark.Session):
"""
Initializes TCM with optional Snowpark session to connect to.
Args:
session: optional, if None automatically retrieves parameters.
"""
self._session = session
self._decoder = None

def get_decoder(self, type):
if type != tcm_proto.SNOWPARK_PYTHON_AST:
raise NotImplementedError
if self._decoder is None:
self._decoder = SnowparkAstDecoder(self._session)
return self._decoder

def construct_tcm_response_from_ast(self, ast_res_proto, type, rid):
uuid = ast_res_proto.body[0].eval_ok.data.sf_query_result.uuid

return tcm_proto.TcmResponse(payload=ast_res_proto.SerializeToString(), type=type,
request_id=rid,
okResult=tcm_proto.TcmResponse.OkResult(query_uuid=uuid))

def request(self, tcm_req_base64: str) -> str:
try:
tcm_req_proto = tcm_proto.TcmRequest()
str2proto(tcm_req_base64, tcm_req_proto)
rid = tcm_req_proto.request_id
logging.debug(f"request id: {rid}")

assert tcm_req_proto.type == tcm_proto.SNOWPARK_PYTHON_AST

decoder = self.get_decoder(tcm_req_proto.type)

ast_req_proto = ast.Request()
ast_req_proto.ParseFromString(tcm_req_proto.payload)

ast_res_proto = decoder.request(ast_req_proto)

tcm_res_proto = self.construct_tcm_response_from_ast(ast_res_proto, tcm_req_proto.type, rid)
return proto2str(tcm_res_proto)
except Exception:
raise



88 changes: 88 additions & 0 deletions src/snowflake/snowpark/_internal/tcm/sp_ast_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from typing import Any

import snowflake
import logging
import snowflake.snowpark._internal.proto.generated.ast_pb2 as ast
logger = logging.getLogger(__name__)


def column_expr(e):
variant = e.WhichOneof('variant')
if variant == 'sp_column_sql_expr':
e = e.sp_column_sql_expr
return e.sql
else:
logger.warning('Unexpected column expr %s', str(e))
return None


class SnowparkAstDecoder:
def __init__(self, session: snowflake.snowpark.Session):
self._session = session
self._session.ast_enabled = True
self.bindings = {}

def request(self, req: ast.Request) -> ast.Response:
resp = ast.Response()

for stmt in req.body:
variant = stmt.WhichOneof('variant')
if variant == 'assign':
self.assign(stmt.assign)
elif variant == 'eval':
ans = self.eval(stmt.eval)

# Fill into response evaluation result, this allows on the client-side to reconstruct
# values with results.
ok_result = ast.EvalOk(uid=stmt.eval.uid, var_id=stmt.eval.var_id, data=ans)
res = ast.Result()
res.eval_ok.CopyFrom(ok_result)
resp.body.add().CopyFrom(res)
else:
logger.warning('Unexpected statement %s', str(stmt))
logger.info('Session bindings %s', str(self.bindings))
return resp

def get_binding(self, var_id):
# TODO: check if valid.
return self.bindings[var_id.bitfield1]

def assign(self, assign) -> None:
val = self.expr(assign.expr)
self.bindings[assign.var_id.bitfield1] = val

def eval(self, eval) -> Any:
res = self.get_binding(eval.var_id)
logger.info('Return atom %s := %s', eval.var_id.bitfield1, str(res))
return res

def expr(self, e):
variant = e.WhichOneof('variant')
if variant == 'sp_dataframe_ref':
e = e.sp_dataframe_ref
return self.get_binding(e.id)
elif variant == 'sp_sql':
e = e.sp_sql
return self._session.sql(e.query)
elif variant == 'sp_table':
e = e.sp_table
return self._session.table(e.table)
elif variant == 'sp_dataframe_filter':
e = e.sp_dataframe_filter
df = self.expr(e.df)
condition = column_expr(e.condition)
return df.filter(condition)
elif variant == 'sp_dataframe_show':
e = e.sp_dataframe_show
df = self.get_binding(e.id)
job = df.collect_nowait()
return job.query_id
elif variant == 'sp_dataframe_collect':
e = e.sp_dataframe_collect
df = self.get_binding(e.id)
job = df.collect_nowait()
return ast.EvalResult(sf_query_result=ast.SfQueryResult(uuid=job.query_id))
else:
logger.warning('Unexpected expr %s', str(e))
return None

84 changes: 84 additions & 0 deletions tests/integ/tcm/test_tcm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Test in this file the TCM.
# If run on workspace/devvm with parameters configured, this should internally call a running GS/XP instance.
from snowflake.snowpark._internal.ast_utils import base64_lines_to_textproto
from snowflake.snowpark._internal.tcm.session import TcmSession, str2proto, proto2str
import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
import snowflake.snowpark._internal.proto.generated.SnowparkTcm_pb2 as tcm_proto
from google.protobuf.json_format import MessageToDict
import base64
import json

from snowflake.snowpark.functions import col


def test_ast_gen(session):
tcm = TcmSession(session)
session.ast_enabled = True
# Reset the entity ID generator.
session._ast_batch.reset_id_gen()

session._ast_batch.flush() # Clear the AST.
# Run the test.
with session.ast_listener() as al:
print(session.sql("select ln(3)").collect())

# Retrieve the ASTs corresponding to the test.
df_ast = al.base64_batches
# if last_batch:
# result.append(last_batch)

df_ast = "\n".join(df_ast)
print(df_ast)
print(f"len:{len(df_ast)}")

# GS will receive df_ast

# GS build tcm request
ast_binary = base64.b64decode(df_ast)
print(f"bin len: {len(ast_binary)}")

print(base64_lines_to_textproto(df_ast.strip()))
#
#
# req_proto = proto.Request()
# req_proto.ParseFromString(req_binary)
tcm_request = tcm_proto.TcmRequest(payload=ast_binary, type=tcm_proto.SNOWPARK_PYTHON_AST, request_id=1)
tcm_req_str = str(base64.b64encode(tcm_request.SerializeToString()), "utf-8")
print("tcm request:")

print(json.dumps(MessageToDict(tcm_request), indent=2))
print("tcm request str:")
print(tcm_req_str)
print("tcm request str end:")
tcm_res_str = tcm.request(tcm_req_str)
print(tcm_res_str)
tcm_res = tcm_proto.TcmResponse()

str2proto(tcm_res_str, tcm_res)

ans = MessageToDict(tcm_res)

print(json.dumps(ans, indent=2))

ast_res = proto.Response()
ast_res.ParseFromString(tcm_res.payload)
ans = MessageToDict(ast_res)

print(json.dumps(ans, indent=2))


def test_tcm_eval():
tcm = TcmSession()

req_base64 = "'CgIIKhIICgYKBAgDEAgacwpxCAESAggBGmOKCmAKUgpOL1VzZXJzL2xzcGllZ2VsYmVyZy9wcm9qZWN0cy9zbm93cGFyay1weXRob24vdGVzdHMvdGhpbi1jbGllbnQvc3RlZWwtdGhyZWFkLnB5EDciCnRlc3RfdGFibGUiBAoCZGYaEwoRCAISAggCGgfqBwQSAggBIgAaCBIGCAMSAggC'"
print(base64_lines_to_textproto(req_base64.strip()))
req_binary = base64.b64decode(req_base64)
req_proto = proto.Request()
req_proto.ParseFromString(req_binary)
tcm_request = proto.TcmRequest(request=req_proto, sequence_id=1)

ans_proto = tcm.request(tcm_request)
ans = MessageToDict(ans_proto)


print(json.dumps(ans, indent=2))
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ allowlist_externals = bash, protoc
deps = protobuf
commands =
protoc --proto_path=src/snowflake/snowpark/_internal/proto/ --python_out=src/snowflake/snowpark/_internal/proto/generated --pyi_out=src/snowflake/snowpark/_internal/proto/generated/ src/snowflake/snowpark/_internal/proto/ast.proto
protoc --proto_path=src/snowflake/snowpark/_internal/proto/ --python_out=src/snowflake/snowpark/_internal/proto/generated --pyi_out=src/snowflake/snowpark/_internal/proto/generated/ src/snowflake/snowpark/_internal/proto/SnowparkTcm.proto

[testenv:dev]
description = create dev environment
Expand Down

0 comments on commit 1a292e3

Please sign in to comment.