[BSE-4452] Port remaining Python Entry Code to use the Defined Interface (#115)
njriasan authored Jan 7, 2025
1 parent 9e83b60 commit 7efc721
Showing 11 changed files with 438 additions and 409 deletions.
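
The change applies one pattern throughout: Python code that previously constructed py4j-wrapped Java classes directly (ColumnDataTypeClass(...), RelationalAlgebraGeneratorClass(...)) or called methods on raw Java objects (schema.addTable(table), generator.getLoweredGlobalVariables()) now goes through static helpers on the JavaEntryPoint interface (buildColumnDataTypeInfo, buildRelationalAlgebraGenerator, addTableToSchema, getLoweredGlobals). Below is a minimal pure-Python sketch of that facade pattern; the class bodies are simplified stand-ins for illustration, not the actual bodosql/py4j implementation.

# Illustrative sketch only: pure-Python stand-ins for the pattern this commit
# adopts. The real code wraps Java classes through a py4j gateway.

class _ColumnDataTypeClass:
    # Stand-in for a Java class that Python previously constructed directly.
    def __init__(self, type_enum, nullable, *extra):
        self.type_enum = type_enum
        self.nullable = nullable
        self.extra = extra

class JavaEntryPoint:
    # Stand-in for the defined entry-point interface: Python calls these
    # helpers instead of touching Java classes and objects directly, keeping
    # the Java-facing surface behind a single reviewed interface.
    @staticmethod
    def buildColumnDataTypeInfo(type_enum, nullable, *extra):
        return _ColumnDataTypeClass(type_enum, nullable, *extra)

# Before this commit: direct construction of the Java class wrapper.
old_style = _ColumnDataTypeClass(8, True, 9)

# After this commit: construction goes through the entry point.
new_style = JavaEntryPoint.buildColumnDataTypeInfo(8, True, 9)

assert (new_style.type_enum, new_style.nullable, new_style.extra) == (
    old_style.type_enum,
    old_style.nullable,
    old_style.extra,
)
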
3 changes: 1 addition & 2 deletions BodoSQL/bodosql/bodosql_types/snowflake_catalog.py
@@ -28,8 +28,7 @@
raise_bodo_error,
)
from bodosql import DatabaseCatalog, DatabaseCatalogType
from bodosql.imported_java_classes import JavaEntryPoint
from bodosql.py4j_gateway import build_java_properties
from bodosql.imported_java_classes import JavaEntryPoint, build_java_properties


def _validate_constructor_args(
125 changes: 48 additions & 77 deletions BodoSQL/bodosql/context.py
@@ -6,7 +6,7 @@
import time
import traceback
import warnings
from enum import Enum, IntEnum
from enum import Enum
from typing import Any

import numba
@@ -22,11 +22,10 @@
from bodosql.bodosql_types.database_catalog import DatabaseCatalog
from bodosql.bodosql_types.table_path import TablePath, TablePathType
from bodosql.imported_java_classes import (
ColumnDataTypeClass,
JavaEntryPoint,
RelationalAlgebraGeneratorClass,
build_java_array_list,
build_java_hash_map,
)
from bodosql.py4j_gateway import build_java_array_list, build_java_hash_map
from bodosql.utils import BodoSQLWarning, error_to_string

# Prefix to add to table argument names when passed to JIT to avoid variable name conflicts
@@ -123,17 +122,6 @@ class SqlTypeEnum(Enum):
}


# Hacky way to get the planner type option to Java.
# I don't want to access the Java enum class or the constants
# defined in Java that are used for this decision from Python
# so we're going to redefine the enum here.
#
# Not intended as a public API.
class _PlannerType(IntEnum):
Volcano = 0
Streaming = 1


def construct_tz_aware_array_type(typ, nullable):
"""Construct a BodoSQL data type for a tz-aware timestamp array
@@ -151,12 +139,12 @@ def construct_tz_aware_array_type(typ, nullable):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Timestamp_Ntz.value
)
return ColumnDataTypeClass(type_enum, nullable, precision)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable, precision)
else:
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Timestamp_Ltz.value
)
return ColumnDataTypeClass(type_enum, nullable, precision)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable, precision)


def construct_time_array_type(typ: bodo.TimeArrayType | bodo.TimeType, nullable: bool):
@@ -172,7 +160,7 @@ def construct_time_array_type(typ: bodo.TimeArrayType | bodo.TimeType, nullable:
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Time.value
)
return ColumnDataTypeClass(type_enum, nullable, typ.precision)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable, typ.precision)


def construct_array_item_array_type(arr_type):
@@ -190,7 +178,7 @@ def construct_array_item_array_type(arr_type):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Array.value
)
return ColumnDataTypeClass(type_enum, True, child)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, True, child)


def construct_json_array_type(arr_type):
@@ -210,23 +198,23 @@ def construct_json_array_type(arr_type):
key_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.String.value
)
key = ColumnDataTypeClass(key_enum, True)
key = JavaEntryPoint.buildColumnDataTypeInfo(key_enum, True)
value_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Variant.value
)
value = ColumnDataTypeClass(value_enum, True)
value = JavaEntryPoint.buildColumnDataTypeInfo(value_enum, True)
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Json_Object.value
)
return ColumnDataTypeClass(type_enum, True, key, value)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, True, key, value)
else:
# TODO: Add map scalar support
key = get_sql_data_type(arr_type.key_arr_type)
value = get_sql_data_type(arr_type.value_arr_type)
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Json_Object.value
)
return ColumnDataTypeClass(type_enum, True, key, value)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, True, key, value)


def get_sql_column_type(arr_type, col_name):
@@ -251,15 +239,15 @@ def get_sql_data_type(arr_type):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Timestamp_Tz.value
)
return ColumnDataTypeClass(type_enum, nullable)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable)
elif isinstance(arr_type, bodo.TimeArrayType):
# Time array types have their own special handling for precision
return construct_time_array_type(arr_type, nullable)
elif isinstance(arr_type, bodo.DecimalArrayType):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Decimal.value
)
return ColumnDataTypeClass(
return JavaEntryPoint.buildColumnDataTypeInfo(
type_enum, nullable, arr_type.precision, arr_type.scale
)
elif isinstance(arr_type, bodo.ArrayItemArrayType):
@@ -270,13 +258,13 @@ def get_sql_data_type(arr_type):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
_numba_to_sql_column_type_map[arr_type.dtype]
)
return ColumnDataTypeClass(type_enum, nullable)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable)
elif isinstance(arr_type.dtype, bodo.PDCategoricalDtype):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Categorical.value
)
child = get_sql_data_type(dtype_to_array_type(arr_type.dtype.elem_type, True))
return ColumnDataTypeClass(type_enum, nullable, child)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable, child)
else:
# The type is unsupported we raise a warning indicating this is a possible
# error but we generate a dummy type because we may be able to support it
@@ -285,7 +273,7 @@ def get_sql_data_type(arr_type):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Unsupported.value
)
return ColumnDataTypeClass(type_enum, nullable)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable)


def create_java_dynamic_parameter_type_list(dynamic_params_list: list[Any]):
@@ -333,7 +321,7 @@ def get_sql_param_column_type_info(param_type: types.Type):
Args:
param_type (types.Type): The bodo type to lower as a parameter.
Return:
JavaObject: The ColumnDataTypeClass for the parameter type.
JavaObject: The ColumnDataTypeInfo for the parameter type.
"""
unliteral_type = types.unliteral(param_type)
# The named parameters are always scalars. We don't support
@@ -353,14 +341,14 @@ def get_sql_param_column_type_info(param_type: types.Type):
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
SqlTypeEnum.Decimal.value
)
return ColumnDataTypeClass(
return JavaEntryPoint.buildColumnDataTypeInfo(
type_enum, nullable, unliteral_type.precision, unliteral_type.scale
)
elif unliteral_type in _numba_to_sql_param_type_map:
type_enum = JavaEntryPoint.buildBodoSQLColumnDataTypeFromTypeId(
_numba_to_sql_param_type_map[unliteral_type]
)
return ColumnDataTypeClass(type_enum, nullable)
return JavaEntryPoint.buildColumnDataTypeInfo(type_enum, nullable)
raise TypeError(
f"Dynamic Parameter with type {param_type} not supported in BodoSQL. Please cast your data to a supported type. https://docs.bodo.ai/latest/source/BodoSQL.html#supported-data-types"
)
@@ -579,7 +567,7 @@ def add_table_type(

table = JavaEntryPoint.buildLocalTable(
table_name,
schema.getFullPath(),
schema,
col_arr,
is_writeable,
read_code,
@@ -591,7 +579,7 @@ def add_table_type(
estimated_row_count,
estimated_ndvs_java_map,
)
schema.addTable(table)
JavaEntryPoint.addTableToSchema(schema, table)


def _get_estimated_row_count(table: pd.DataFrame | TablePath) -> int | None:
@@ -963,20 +951,19 @@ def _convert_to_pandas(
sql: str,
dynamic_params_list: list[Any],
named_params_dict: dict[str, Any],
generator: RelationalAlgebraGeneratorClass,
generator,
is_ddl: bool,
) -> tuple[str, dict[str, Any]]:
"""Generate the func_text for the Python code generated for the given SQL query.
This is always computed entirely on rank 0 to avoid parallelism errors.
Args:
sql (str): The SQL query to process.
optimize_plan (bool): Should the generated plan be optimized?
params_dict (Dict[str, Any]): A python dictionary mapping Python variables
to usable SQL names that can be referenced in the query.
hide_credentials (bool): Should credentials be hidden in the generated code. This
is set to true when we want to inspect the code but not run the code.
dynamic_params_list (List[Any]): The list of dynamic parameters to lower.
named_params_dict (Dict[str, Any]): The named parameters to lower.
generator (RelationalAlgebraGenerator Java Object): The relational algebra generator
used to generate the code.
is_ddl (bool): Is this a DDL query?
Raises:
BodoError: If the SQL query cannot be processed.
@@ -1185,15 +1172,15 @@ def generate_plan(
def _get_pandas_code(
self,
sql: str,
generator: RelationalAlgebraGeneratorClass,
generator,
dynamic_params_list: list[Any],
named_params_dict: dict[str, Any],
) -> tuple[str, dict[str, Any]]:
"""Generate the Pandas code for the given SQL string.
Args:
sql (str): The SQL query text.
generator (RelationalAlgebraGeneratorClass): The relational algebra generator
generator (RelationalAlgebraGenerator Java Object): The relational algebra generator
used to generate the code.
Raises:
@@ -1224,7 +1211,7 @@ def _get_pandas_code(
raise bodo.utils.typing.BodoError(
f"Unable to compile SQL Query. Error message:\n{message}"
)
return pd_code, generator.getLoweredGlobalVariables()
return pd_code, JavaEntryPoint.getLoweredGlobals(generator)

def _create_generator(self, hide_credentials: bool):
"""Creates a RelationalAlgebraGenerator from the schema.
@@ -1234,37 +1221,19 @@ def _create_generator(self, hide_credentials: bool):
any generated code.
Returns:
RelationalAlgebraGeneratorClass: The java object holding
RelationalAlgebraGenerator Java Object: The java object holding
the relational algebra generator.
"""
verbose_level = bodo.user_logging.get_verbose_level()
tracing_level = bodo.tracing_level
if bodo.bodosql_use_streaming_plan:
planner_type = _PlannerType.Streaming.value
else:
planner_type = _PlannerType.Volcano.value
if self.catalog is not None:
return RelationalAlgebraGeneratorClass(
self.catalog.get_java_object(),
self.schema,
planner_type,
verbose_level,
tracing_level,
bodo.bodosql_streaming_batch_size,
hide_credentials,
bodo.enable_snowflake_iceberg,
bodo.enable_timestamp_tz,
bodo.enable_runtime_join_filters,
bodo.enable_streaming_sort,
bodo.enable_streaming_sort_limit_offset,
bodo.bodo_sql_style,
bodo.bodosql_full_caching,
bodo.prefetch_sf_iceberg,
)
extra_args = () if self.default_tz is None else (self.default_tz,)
generator = RelationalAlgebraGeneratorClass(
catalog_obj = self.catalog.get_java_object()
else:
catalog_obj = None
return JavaEntryPoint.buildRelationalAlgebraGenerator(
catalog_obj,
self.schema,
planner_type,
bodo.bodosql_use_streaming_plan,
verbose_level,
tracing_level,
bodo.bodosql_streaming_batch_size,
@@ -1277,9 +1246,8 @@ def _create_generator(self, hide_credentials: bool):
bodo.bodo_sql_style,
bodo.bodosql_full_caching,
bodo.prefetch_sf_iceberg,
*extra_args,
self.default_tz,
)
return generator

def add_or_replace_view(self, name: str, table: pd.DataFrame | TablePath):
"""Create a new BodoSQLContext that contains all of the old DataFrames and the
@@ -1401,9 +1369,7 @@ def __eq__(self, bc: object) -> bool:
return self.catalog == bc.catalog
return False # pragma: no cover

def execute_ddl(
self, sql: str, generator: RelationalAlgebraGeneratorClass | None = None
) -> pd.DataFrame:
def execute_ddl(self, sql: str, generator=None) -> pd.DataFrame:
"""API to directly execute DDL queries. This is used by the JIT
path to execute DDL queries and can be used as a fast path when you
statically know the query you want to execute is a DDL query to avoid the
@@ -1414,7 +1380,7 @@ def execute_ddl(
Args:
sql (str): The DDL query to execute.
generator (Optional[RelationalAlgebraGeneratorClass]): The prepared planner
generator (Optional[RelationalAlgebraGenerator Java object]): The prepared planner
information used for executing the query. If None we need to create
the planner.
@@ -1440,14 +1406,19 @@ def execute_ddl(
try:
ddl_result = JavaEntryPoint.executeDDL(generator, sql)
# Convert the output to a DataFrame.
column_names = list(ddl_result.getColumnNames())
column_names = list(
JavaEntryPoint.getDDLExecutionColumnNames(ddl_result)
)
column_types = [
_generate_ddl_column_type(t) for t in ddl_result.getColumnTypes()
_generate_ddl_column_type(t)
for t in JavaEntryPoint.getDDLExecutionColumnTypes(ddl_result)
]
data = [
# Use astype to avoid issues with Java conversion.
pd.array(column, dtype=object).astype(column_types[i])
for i, column in enumerate(ddl_result.getColumnValues())
for i, column in enumerate(
JavaEntryPoint.getDDLColumnValues(ddl_result)
)
]
df_dict = {column_names[i]: data[i] for i in range(len(column_names))}
result = pd.DataFrame(
(Diffs for the remaining 9 changed files are not shown.)
