Skip to content

Commit

Permalink
add pandas types
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-yuwang committed Nov 19, 2024
1 parent c686b56 commit 798f8a9
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 2 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
- `simple_string`: Provides a simple string representation of the data.
- `json_value`: Returns the data as a JSON-compatible value.
- `json`: Converts the data to a JSON string.
- To `ArrayType`, `MapType`, `StructField`, and `StructType`:
- To `ArrayType`, `MapType`, `StructField`, `PandasSeriesType`, `PandasDataFrameType`, and `StructType`:
- `from_json`: Enables these types to be created from JSON data.
- To `MapType`:
- `keyType`: keys of the map
Expand Down
68 changes: 68 additions & 0 deletions src/snowflake/snowpark/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,34 @@ class PandasSeriesType(_PandasType):
def __init__(self, element_type: Optional[DataType]) -> None:
    """Store the element type of the series.

    Args:
        element_type: Data type of the series' values. ``None`` is accepted
            and stored as-is.
    """
    self.element_type = element_type

def __repr__(self) -> str:
    """Return ``PandasSeriesType(<repr of element_type>)``.

    The parentheses are left empty when ``element_type`` is falsy
    (for example ``None``).
    """
    if self.element_type:
        inner = repr(self.element_type)
    else:
        inner = ""
    return f"PandasSeriesType({inner})"

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "PandasSeriesType":
    """Build a series type from a dictionary shaped like ``json_value()`` output.

    Args:
        json_dict: Mapping with an ``"element_type"`` entry. A truthy entry is
            parsed with ``_parse_datatype_json_value``; a falsy entry (such as
            ``None``) yields a series without an element type.

    Returns:
        A new instance of ``cls``.

    Raises:
        KeyError: If ``json_dict`` has no ``"element_type"`` key.
    """
    raw_element_type = json_dict["element_type"]
    element_type = (
        _parse_datatype_json_value(raw_element_type) if raw_element_type else None
    )
    # Instantiate ``cls`` instead of a hard-coded class name so a subclass
    # calling ``from_json`` gets an instance of its own type.
    return cls(element_type)

def simple_string(self) -> str:
    """Return ``pandasseries<...>`` with the element type's simple string inside.

    The angle brackets are left empty when ``element_type`` is falsy.
    """
    if self.element_type:
        inner = self.element_type.simple_string()
    else:
        inner = ""
    return f"pandasseries<{inner}>"

def json_value(self) -> Dict[str, Any]:
    """Return a dictionary with this type's name and its element type.

    Returns:
        ``{"type": <type_name()>, "element_type": <value>}`` where the value
        is ``element_type.json_value()``, or ``None`` when ``element_type``
        is falsy.
    """
    payload: Dict[str, Any] = {"type": self.type_name()}
    if self.element_type:
        payload["element_type"] = self.element_type.json_value()
    else:
        payload["element_type"] = None
    return payload

# camelCase aliases: each name below is bound to the same object as the
# corresponding snake_case method defined above.
simpleString = simple_string
jsonValue = json_value
fromJson = from_json


class PandasDataFrameType(_PandasType):
"""
Expand All @@ -657,13 +685,52 @@ def __init__(
self.col_types = col_types
self.col_names = col_names or []

def __repr__(self) -> str:
    """Return ``PandasDataFrameType([<type reprs>][, [<names>]])``.

    Column types are rendered with ``repr``. Column names are appended as a
    second bracketed list only when ``col_names`` is not an empty list, and
    they are joined as-is, without quotes.
    """
    if self.col_names != []:
        names_suffix = f", [{', '.join(self.col_names)}]"
    else:
        names_suffix = ""
    type_reprs = ", ".join([repr(col_type) for col_type in self.col_types])
    return f"PandasDataFrameType([{type_reprs}]{names_suffix})"

def get_snowflake_col_datatypes(self):
    """Get the column types of the dataframe as the input/output of a vectorized UDTF.

    Each ``PandasSeriesType`` entry in ``col_types`` is replaced by its
    ``element_type``; every other entry is returned unchanged.
    """
    unwrapped = []
    for col_type in self.col_types:
        if isinstance(col_type, PandasSeriesType):
            unwrapped.append(col_type.element_type)
        else:
            unwrapped.append(col_type)
    return unwrapped

@classmethod
def from_json(cls, json_dict: Dict[str, Any]) -> "PandasDataFrameType":
    """Build a dataframe type from a dictionary shaped like ``json_value()`` output.

    Args:
        json_dict: Mapping with a ``"fields"`` list; every field is a mapping
            with a ``"name"`` string and a JSON-encoded ``"type"``.

    Returns:
        A new instance of ``cls`` whose column types are parsed with
        ``_parse_datatype_json_value``. Column names are kept, one per
        column, unless every name is the empty string, in which case no
        names are passed.

    Raises:
        KeyError: If ``"fields"``, or a field's ``"name"`` or ``"type"``, is
            missing.
    """
    fields = json_dict["fields"]
    col_names = [field["name"] for field in fields]
    col_types = [_parse_datatype_json_value(field["type"]) for field in fields]
    # ``json_value`` writes "" for every column when the type has no names, so
    # only an all-empty list means "no names". Dropping empty names one by one
    # would leave ``col_names`` shorter than ``col_types`` and shift the
    # remaining names onto the wrong columns.
    if not any(name != "" for name in col_names):
        col_names = []
    # Instantiate ``cls`` instead of a hard-coded class name so a subclass
    # calling ``from_json`` gets an instance of its own type.
    return cls(col_types, col_names)

def simple_string(self) -> str:
    """Return ``pandas<...>`` listing each column type's simple string, comma-separated."""
    parts = [col_type.simple_string() for col_type in self.col_types]
    return "pandas<" + ",".join(parts) + ">"

def json_value(self) -> Dict[str, Any]:
    """Return a dictionary with this type's name and one entry per column.

    Returns:
        ``{"type": <type_name()>, "fields": [...]}`` where each field comes
        from ``_json_value_helper(name, col_type)``. When ``col_names`` is an
        empty list, every column is given the empty string as its name.
    """
    if self.col_names != []:
        names = self.col_names
    else:
        # No names recorded: use one "" placeholder per column type.
        names = [""] * len(list(self.col_types))

    type_label = self.type_name()
    fields = []
    for name, col_type in zip(names, self.col_types):
        fields.append(self._json_value_helper(name, col_type))
    return {"type": type_label, "fields": fields}

def _json_value_helper(self, col_name, col_type) -> Dict[str, Any]:
    """Return the entry for one column: its name and its type's ``json_value()``."""
    type_json = col_type.json_value()
    return {"name": col_name, "type": type_json}

# camelCase aliases: each name below is bound to the same object as the
# corresponding snake_case method defined above.
simpleString = simple_string
jsonValue = json_value
fromJson = from_json


_atomic_types: List[Type[DataType]] = [
StringType,
Expand All @@ -686,6 +753,7 @@ def get_snowflake_col_datatypes(self):
ArrayType,
MapType,
StructType,
PandasDataFrameType,
]
_all_complex_types: Dict[str, Type[Union[ArrayType, MapType, StructType]]] = {
v.typeName(): v for v in _complex_types
Expand Down
70 changes: 69 additions & 1 deletion tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,61 @@ def test_snow_type_to_dtype_str():
"vector",
"vector(float,8)",
),
(
PandasDataFrameType(
[StringType(), IntegerType(), FloatType()], ["id", "col1", "col2"]
),
"pandas<string,int,float>",
'{"fields":[{"name":"id","type":"string"},{"name":"col1","type":"integer"},{"name":"col2","type":"float"}],"type":"pandasdataframe"}',
"pandasdataframe",
{
"type": "pandasdataframe",
"fields": [
{"name": "id", "type": "string"},
{"name": "col1", "type": "integer"},
{"name": "col2", "type": "float"},
],
},
),
(
PandasDataFrameType(
[ArrayType(ArrayType(IntegerType())), IntegerType(), FloatType()]
),
"pandas<array<array<int>>,int,float>",
'{"fields":[{"name":"","type":{"element_type":{"element_type":"integer","type":"array"},"type":"array"}},{"name":"","type":"integer"},{"name":"","type":"float"}],"type":"pandasdataframe"}',
"pandasdataframe",
{
"type": "pandasdataframe",
"fields": [
{
"name": "",
"type": {
"type": "array",
"element_type": {
"type": "array",
"element_type": "integer",
},
},
},
{"name": "", "type": "integer"},
{"name": "", "type": "float"},
],
},
),
(
PandasSeriesType(IntegerType()),
"pandasseries<int>",
'{"element_type":"integer","type":"pandasseries"}',
"pandasseries",
{"type": "pandasseries", "element_type": "integer"},
),
(
PandasSeriesType(None),
"pandasseries<>",
'{"element_type":null,"type":"pandasseries"}',
"pandasseries",
{"type": "pandasseries", "element_type": None},
),
],
)
def test_datatype(tpe, simple_string, json, type_name, json_value):
Expand Down Expand Up @@ -1303,6 +1358,20 @@ def test_datatype(tpe, simple_string, json, type_name, json_value):
StructField,
StructField("AA", DecimalType(20, 10)),
),
(
PandasDataFrameType,
PandasDataFrameType(
[StringType(), IntegerType(), FloatType()], ["id", "col1", "col2"]
),
),
(
PandasDataFrameType,
PandasDataFrameType(
[ArrayType(ArrayType(IntegerType())), IntegerType(), FloatType()]
),
),
(PandasSeriesType, PandasSeriesType(IntegerType())),
(PandasSeriesType, PandasSeriesType(None)),
],
)
def test_structtype_from_json(datatype, tpe):
Expand Down Expand Up @@ -1342,4 +1411,3 @@ def test_maptype_alias():

assert tpe.valueType == tpe.value_type
assert tpe.keyType == tpe.key_type

0 comments on commit 798f8a9

Please sign in to comment.