Background
This article is based on Spark 3.5.1.
Neither the Arrow Python UDFs page in the official Spark documentation nor Databricks' material on Python UDFs seems to explain how to call a Python-defined UDF directly from Spark SQL. In practice, however, Spark SQL can use Python-defined UDFs directly.
The goal of this article is to make clear how a Python-registered UDF can be called from Spark SQL, where the SQL does not have to be submitted through the Python API: it can just as well be submitted through the Java or Scala API.
Invocation
Looking at the SubquerySuite class in the Spark source, we can find the following example:
import IntegratedUDFTestUtils._

assume(shouldTestPythonUDFs)
val pythonTestUDF = TestPythonUDF(name = "udf")
registerTestUDF(pythonTestUDF, spark)

// Case 1: Canonical example of the COUNT bug
checkAnswer(
  sql("SELECT l.a FROM l WHERE (SELECT udf(count(*)) FROM r WHERE l.a = r.c) < l.a"),
  Row(1) :: Row(1) :: Row(3) :: Row(6) :: Nil)
where registerTestUDF is defined as follows:
def registerTestUDF(testUDF: TestUDF, session: SparkSession): Unit = testUDF match {
  case udf: TestPythonUDF => session.udf.registerPython(udf.name, udf.udf)
  case udf: TestScalarPandasUDF => session.udf.registerPython(udf.name, udf.udf)
  case udf: TestGroupedAggPandasUDF => session.udf.registerPython(udf.name, udf.udf)
  case udf: TestScalaUDF =>
    val registry = session.sessionState.functionRegistry
    registry.createOrReplaceTempFunction(udf.name, udf.builder, "scala_udf")
  case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]")
}
Here, Spark SQL inside Scala code calls a UDF registered from Python, and judging by the unit test results nothing goes wrong. The key piece is TestPythonUDF, which is analyzed below.
Analysis
How the Python UDF is registered
Going straight to the TestPythonUDF code block:
case class TestPythonUDF(name: String, returnType: Option[DataType] = None) extends TestUDF {
  private[IntegratedUDFTestUtils] lazy val udf = new UserDefinedPythonFunction(
    name = name,
    func = SimplePythonFunction(
      command = pythonFunc.toImmutableArraySeq,
      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
      pythonIncludes = List.empty[String].asJava,
      pythonExec = pythonExec,
      pythonVer = pythonVer,
      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
      accumulator = null),
    dataType = StringType,
    pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
    udfDeterministic = true) {

    override def builder(e: Seq[Expression]): Expression = {
      assert(e.length == 1, "Defined UDF only has one column")
      val expr = e.head
      val rt = returnType.getOrElse {
        assert(expr.resolved, "column should be resolved to use the same type " +
          "as input. Try df(name) or df.col(name)")
        expr.dataType
      }
      val pythonUDF = new PythonUDFWithoutId(
        super.builder(Cast(expr, StringType) :: Nil).asInstanceOf[PythonUDF])
      Cast(pythonUDF, rt)
    }
  }

  def apply(exprs: Column*): Column = udf(exprs: _*)

  val prettyName: String = "Regular Python UDF"
}
The most important part is the command inside the SimplePythonFunction of the UserDefinedPythonFunction: that is the core of the Python UDF. The corresponding pythonFunc is shown below:
private lazy val pythonFunc: Array[Byte] = if (shouldTestPythonUDFs) {
  var binaryPythonFunc: Array[Byte] = null
  withTempPath { path =>
    Process(
      Seq(
        pythonExec,
        "-c",
        "from pyspark.sql.types import StringType; " +
          "from pyspark.serializers import CloudPickleSerializer; " +
          s"f = open('$path', 'wb');" +
          "f.write(CloudPickleSerializer().dumps((" +
          "lambda x: None if x is None else str(x), StringType())))"),
      None,
      "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
    binaryPythonFunc = Files.readAllBytes(path.toPath)
  }
  assert(binaryPythonFunc != null)
  binaryPythonFunc
} else {
  throw new RuntimeException(s"Python executable [$pythonExec] and/or pyspark are unavailable.")
}
The logic of this code is: serialize the Python lambda (together with its return type) to a file with CloudPickleSerializer, then read back the binary array stored in that file; that byte array becomes pythonFunc. This byte array, passed as command, is later deserialized back into the corresponding function and invoked.
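To make this concrete, here is a minimal sketch (my own illustration, not Spark-internal code, assuming pyspark is installed) of what that command byte array is: a cloudpickle dump of the (function, return type) pair, which can be loaded back and invoked, conceptually what the Python worker does later.

from pyspark.serializers import CloudPickleSerializer
from pyspark.sql.types import StringType

ser = CloudPickleSerializer()

# Same payload as the Scala snippet above writes into the temp file.
command = ser.dumps((lambda x: None if x is None else str(x), StringType()))

# What the worker conceptually does later: deserialize and invoke.
func, return_type = ser.loads(command)
print(func(42), return_type)  # -> 42 StringType()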
Note
In normal usage, for example in udtf.py, `py4j` is used from Python to call the Java method that registers the UDF, as follows:
class UDTFRegistration:
    ...
        register_udtf = _create_udtf(
            cls=f.func,
            returnType=f.returnType,
            name=name,
            evalType=f.evalType,
            deterministic=f.deterministic,
        )
        self.sparkSession._jsparkSession.udtf().registerPython(name, register_udtf._judtf)
        return register_udtf
Here udtf() returns a UDTFRegistration, and _judtf returns a UserDefinedPythonTableFunction; the command inside it is the byte array produced by CloudPickleSerializer serialization.
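For regular scalar Python UDFs the path is analogous: spark.udf.register in udf.py goes through py4j and ends up calling UDFRegistration.registerPython on the JVM side. A minimal sketch of that user-facing path (the UDF name to_str is just an example):

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()

# Register a scalar Python UDF; under the hood this goes through py4j to the
# JVM-side UDFRegistration.registerPython.
spark.udf.register("to_str", lambda x: None if x is None else str(x), StringType())

# Once registered, the function is resolvable from plain SQL text, regardless
# of whether the SQL is submitted through the Python, Scala, or Java API of
# the same session.
spark.sql("SELECT to_str(id) AS s FROM range(3)").show()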
Data flow of a UDF call
Let's start from UDFRegistration.registerPython, the registration entry point. It wires up the builder method of UserDefinedPythonFunction, which produces a PythonUDF expression. That PythonUDF cannot be evaluated on its own, so it goes through the following rule transformations:
PythonUDF
   ||
   \/   via Rule ExtractPythonUDFs
BatchEvalPython / ArrowEvalPython
   ||
   \/   via Rule PythonEvals
ArrowEvalPythonExec / BatchEvalPythonExec
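A quick way to observe this rewrite is to look at the physical plan of a query that uses a registered Python UDF; the plan text in the comments below only indicates the expected shape, not exact output.

from pyspark.sql import SparkSession
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()
spark.udf.register("to_str", lambda x: str(x), StringType())

# The physical plan should show a BatchEvalPython (or ArrowEvalPython) node
# inserted by ExtractPythonUDFs / PythonEvals above the Range scan.
spark.sql("SELECT to_str(id) FROM range(3)").explain()
# == Physical Plan ==
# ... BatchEvalPython [to_str(id)#...]
#    +- Range (0, 3, ...)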
Take ArrowEvalPythonExec as the example: it ultimately calls the EvalPythonEvaluatorFactory.compute method. This starts a worker.py process (python -m pyspark.daemon pyspark.worker), and inside worker.py:
if eval_type == PythonEvalType.NON_UDF:
    func, profiler, deserializer, serializer = read_command(pickleSer, infile)
elif eval_type in (PythonEvalType.SQL_TABLE_UDF, PythonEvalType.SQL_ARROW_TABLE_UDF):
    func, profiler, deserializer, serializer = read_udtf(pickleSer, infile, eval_type)
else:
    func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)

init_time = time.time()

def process():
    iterator = deserializer.load_stream(infile)
    out_iter = func(split_index, iterator)
    try:
        serializer.dump_stream(out_iter, outfile)
    finally:
        if hasattr(out_iter, "close"):
            out_iter.close()

if profiler:
    profiler.profile(process)
else:
    process()
Here, `func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)` deserializes the UDF function, and `out_iter = func(split_index, iterator)` applies it to the incoming data.
The key part of read_udfs is:
arg_offsets, udf = read_single_udf(
    pickleSer, infile, eval_type, runner_conf, udf_index=0, profiler=profiler)

def func(_, iterator):
    num_input_rows = 0

    def map_batch(batch):
        nonlocal num_input_rows
        udf_args = [batch[offset] for offset in arg_offsets]
        num_input_rows += len(udf_args[0])
        if len(udf_args) == 1:
            return udf_args[0]
        else:
            return tuple(udf_args)

    iterator = map(map_batch, iterator)
    result_iter = udf(iterator)
As we can see, this way of running Python UDFs exchanges data with the Python worker over a socket, so it is comparatively slow.
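As a side note, since Spark 3.5 a regular Python UDF can opt into Arrow serialization (the ArrowEvalPython path shown above), which keeps the socket round trip but avoids per-row pickling; a sketch, assuming pyspark 3.5+:

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

spark = SparkSession.builder.getOrCreate()

# Per-UDF opt-in; there is also the session-wide config
# spark.sql.execution.pythonUDF.arrow.enabled.
to_str_arrow = udf(lambda x: None if x is None else str(x), StringType(), useArrow=True)

spark.range(3).select(to_str_arrow("id").alias("s")).show()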