#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
User-defined table function related classes and functions
"""

import pickle
import sys
import warnings
from typing import Any, Type, TYPE_CHECKING, Optional, Union
from py4j.java_gateway import JavaObject
from pyspark.errors import PySparkAttributeError, PySparkRuntimeError, PySparkTypeError
from pyspark.rdd import PythonEvalType
from pyspark.sql.column import _to_java_column, _to_seq
from pyspark.sql.pandas.utils import require_minimum_pandas_version, require_minimum_pyarrow_version
from pyspark.sql.types import StructType, _parse_datatype_string
from pyspark.sql.udf import _wrap_function

if TYPE_CHECKING:
from pyspark.sql._typing import ColumnOrName
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.session import SparkSession

__all__ = ["UDTFRegistration"]


def _create_udtf(
cls: Type,
returnType: Union[StructType, str],
name: Optional[str] = None,
evalType: int = PythonEvalType.SQL_TABLE_UDF,
deterministic: bool = False,
) -> "UserDefinedTableFunction":
"""Create a Python UDTF with the given eval type."""
udtf_obj = UserDefinedTableFunction(
cls, returnType=returnType, name=name, evalType=evalType, deterministic=deterministic
)
return udtf_obj


def _create_py_udtf(
cls: Type,
returnType: Union[StructType, str],
name: Optional[str] = None,
deterministic: bool = False,
useArrow: Optional[bool] = None,
) -> "UserDefinedTableFunction":
"""Create a regular or an Arrow-optimized Python UDTF."""
# Determine whether to create Arrow-optimized UDTFs.
if useArrow is not None:
arrow_enabled = useArrow
else:
from pyspark.sql import SparkSession
session = SparkSession._instantiatedSession
arrow_enabled = False
if session is not None:
value = session.conf.get("spark.sql.execution.pythonUDTF.arrow.enabled")
if isinstance(value, str) and value.lower() == "true":
arrow_enabled = True
eval_type: int = PythonEvalType.SQL_TABLE_UDF
if arrow_enabled:
# Return the regular UDTF if the required dependencies are not satisfied.
try:
require_minimum_pandas_version()
require_minimum_pyarrow_version()
eval_type = PythonEvalType.SQL_ARROW_TABLE_UDF
except ImportError as e:
warnings.warn(
f"Arrow optimization for Python UDTFs cannot be enabled: {str(e)}. "
f"Falling back to using regular Python UDTFs.",
UserWarning,
)
return _create_udtf(
cls=cls,
returnType=returnType,
name=name,
evalType=eval_type,
deterministic=deterministic,
)
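
# For illustration, a minimal sketch of how the Arrow decision above plays out;
# `SquareNumbers` is a hypothetical handler class and an active SparkSession is
# assumed. An explicit `useArrow` flag takes precedence over the SQL conf
# "spark.sql.execution.pythonUDTF.arrow.enabled":
#
#   class SquareNumbers:
#       def eval(self, n: int):
#           for i in range(n):
#               yield i, i * i
#
#   arrow_udtf = _create_py_udtf(SquareNumbers, "num: int, squared: int", useArrow=True)
#   plain_udtf = _create_py_udtf(SquareNumbers, "num: int, squared: int", useArrow=False)
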
def _validate_udtf_handler(cls: Any) -> None:
"""Validate the handler class of a UDTF."""
if not isinstance(cls, type):
raise PySparkTypeError(
error_class="INVALID_UDTF_HANDLER_TYPE", message_parameters={"type": type(cls).__name__}
)
if not hasattr(cls, "eval"):
raise PySparkAttributeError(
error_class="INVALID_UDTF_NO_EVAL", message_parameters={"name": cls.__name__}
)
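
# For example (illustrative only), both of these handlers would fail validation:
#
#   _validate_udtf_handler(lambda x: x)  # not a class -> INVALID_UDTF_HANDLER_TYPE
#
#   class NoEval:
#       pass
#
#   _validate_udtf_handler(NoEval)  # missing `eval` -> INVALID_UDTF_NO_EVAL
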

class UserDefinedTableFunction:
"""
    User-defined table function in Python

    .. versionadded:: 3.5.0

    Notes
    -----
    The constructor of this class is not supposed to be directly called.
    Use :meth:`pyspark.sql.functions.udtf` to create this instance.

    This API is evolving.
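
    Examples
    --------
    A minimal, illustrative sketch; the handler class below is hypothetical and
    the call is skipped under doctest:

    >>> from pyspark.sql.functions import lit, udtf
    >>> @udtf(returnType="word: string")
    ... class SplitWords:
    ...     def eval(self, text: str):
    ...         for word in text.split(" "):
    ...             yield (word,)
    ...
    >>> SplitWords(lit("hello world")).show()  # doctest: +SKIP
    +-----+
    | word|
    +-----+
    |hello|
    |world|
    +-----+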
"""

    def __init__(
self,
func: Type,
returnType: Union[StructType, str],
name: Optional[str] = None,
evalType: int = PythonEvalType.SQL_TABLE_UDF,
deterministic: bool = False,
):
_validate_udtf_handler(func)
self.func = func
self._returnType = returnType
self._returnType_placeholder: Optional[StructType] = None
self._inputTypes_placeholder = None
self._judtf_placeholder = None
self._name = name or func.__name__
self.evalType = evalType
self.deterministic = deterministic

    @property
def returnType(self) -> StructType:
        # `_parse_datatype_string` accesses the JVM to parse a DDL-formatted string.
        # This makes sure it is called only after the SparkContext has been initialized.
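        # For example, a DDL-formatted string such as "c1: int, c2: string" is parsed
        # into StructType([StructField("c1", IntegerType()), StructField("c2", StringType())]).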
if self._returnType_placeholder is None:
if isinstance(self._returnType, str):
parsed = _parse_datatype_string(self._returnType)
else:
parsed = self._returnType
if not isinstance(parsed, StructType):
raise PySparkTypeError(
error_class="UDTF_RETURN_TYPE_MISMATCH",
message_parameters={
"name": self._name,
"return_type": f"{parsed}",
},
)
self._returnType_placeholder = parsed
return self._returnType_placeholder

    @property
def _judtf(self) -> JavaObject:
if self._judtf_placeholder is None:
self._judtf_placeholder = self._create_judtf(self.func)
return self._judtf_placeholder

    def _create_judtf(self, func: Type) -> JavaObject:
from pyspark.sql import SparkSession
spark = SparkSession._getActiveSessionOrCreate()
sc = spark.sparkContext
try:
wrapped_func = _wrap_function(sc, func)
except pickle.PicklingError as e:
if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
raise PySparkRuntimeError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
"message": "it appears that you are attempting to reference SparkSession "
"inside a UDTF. SparkSession can only be used on the driver, "
"not in code that runs on workers. Please remove the reference "
"and try again.",
},
) from None
raise PySparkRuntimeError(
error_class="UDTF_SERIALIZATION_ERROR",
message_parameters={
"name": self._name,
"message": "Please check the stack trace and make sure the "
"function is serializable.",
},
)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
assert sc._jvm is not None
judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(
self._name, wrapped_func, jdt, self.evalType, self.deterministic
)
return judtf

    def __call__(self, *cols: "ColumnOrName") -> "DataFrame":
from pyspark.sql import DataFrame, SparkSession
spark = SparkSession._getActiveSessionOrCreate()
sc = spark.sparkContext
judtf = self._judtf
jPythonUDTF = judtf.apply(spark._jsparkSession, _to_seq(sc, cols, _to_java_column))
return DataFrame(jPythonUDTF, spark)
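
    # Illustrative usage of `__call__` above (`my_udtf` is a hypothetical instance
    # created via `pyspark.sql.functions.udtf`): `my_udtf(lit(1))` plans a Python
    # UDTF node and returns the result as a DataFrame; a plain string argument such
    # as `my_udtf("some_column")` is resolved as a column name by `_to_java_column`.
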
    def asDeterministic(self) -> "UserDefinedTableFunction":
"""
        Updates the UserDefinedTableFunction to be deterministic.
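
        Examples
        --------
        An illustrative sketch (the handler class is hypothetical; the call is
        skipped under doctest):

        >>> from pyspark.sql.functions import udtf
        >>> import random
        >>> @udtf(returnType="x: int")
        ... class RandomUDTF:
        ...     def eval(self, a: int):
        ...         yield a + int(random.random() * 100),
        ...
        >>> deterministic_udtf = RandomUDTF.asDeterministic()  # doctest: +SKIP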
"""
        # Explicitly clear the cached JVM UDTF instance so it is re-created with the
        # deterministic flag set.
self._judtf_placeholder = None
self.deterministic = True
return self


class UDTFRegistration:
"""
    Wrapper for user-defined table function registration. This instance can be accessed by
    :attr:`spark.udtf` or :attr:`sqlContext.udtf`.

    .. versionadded:: 3.5.0
"""

    def __init__(self, sparkSession: "SparkSession"):
self.sparkSession = sparkSession

    def register(
self,
name: str,
f: "UserDefinedTableFunction",
) -> "UserDefinedTableFunction":
"""Register a Python user-defined table function as a SQL table function.
.. versionadded:: 3.5.0
Parameters
----------
name : str
The name of the user-defined table function in SQL statements.
f : function or :meth:`pyspark.sql.functions.udtf`
The user-defined table function.
Returns
-------
function
The registered user-defined table function.
Notes
-----
Spark uses the return type of the given user-defined table function as the return
type of the registered user-defined function.
To register a nondeterministic Python table function, users need to first build
a nondeterministic user-defined table function and then register it as a SQL function.
Examples
--------
>>> from pyspark.sql.functions import udtf
>>> @udtf(returnType="c1: int, c2: int")
... class PlusOne:
... def eval(self, x: int):
... yield x, x + 1
...
>>> _ = spark.udtf.register(name="plus_one", f=PlusOne)
>>> spark.sql("SELECT * FROM plus_one(1)").collect()
[Row(c1=1, c2=2)]
Use it with lateral join
>>> spark.sql("SELECT * FROM VALUES (0, 1), (1, 2) t(x, y), LATERAL plus_one(x)").collect()
[Row(x=0, y=1, c1=0, c2=1), Row(x=1, y=2, c1=1, c2=2)]
"""
if f.evalType not in [PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF]:
raise PySparkTypeError(
error_class="INVALID_UDTF_EVAL_TYPE",
message_parameters={
"name": name,
"eval_type": "SQL_TABLE_UDF, SQL_ARROW_TABLE_UDF",
},
)
register_udtf = _create_udtf(
cls=f.func,
returnType=f.returnType,
name=name,
evalType=f.evalType,
deterministic=f.deterministic,
)
self.sparkSession._jsparkSession.udtf().registerPython(name, register_udtf._judtf)
return register_udtf


def _test() -> None:
import doctest
from pyspark.sql import SparkSession
    import pyspark.sql.udtf
globs = pyspark.sql.udtf.__dict__.copy()
spark = SparkSession.builder.master("local[4]").appName("sql.udtf tests").getOrCreate()
globs["spark"] = spark
(failure_count, test_count) = doctest.testmod(
pyspark.sql.udtf, globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE
)
spark.stop()
if failure_count:
sys.exit(-1)


if __name__ == "__main__":
_test()