Passing large dataframes with dbutils.notebook.run!
At one point when migrating Databricks notebooks to be usable purely with dbutils.notebook.run, the question came up: hey, dbutils.notebook.run is a great way of calling notebooks explicitly, avoiding the global variables that make code difficult to lint and debug, but what about Spark dataframes?
I had come across a nice bit of documentation, https://docs.databricks.com/notebooks/notebook-workflows.html#pass-structured-data, about using Spark global temp views to shuttle dataframes around by reference: a caller notebook and a callee notebook share a JVM, so handing off a view name is theoretically instantaneous.
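Loosely following that doc page, the whole handoff is just a view name passed around as a string. Here is a minimal sketch; the notebook path and view name are placeholders, df stands in for some Spark dataframe, and spark/dbutils are the usual Databricks notebook globals:

# Callee notebook: register the result and exit with the view's name.
df.createOrReplaceGlobalTempView("my_data")
dbutils.notebook.exit("my_data")

# Caller notebook: run the callee, then look the dataframe up by the returned name.
view_name = dbutils.notebook.run("/path/to/callee_notebook", 60)
result_df = spark.table("global_temp." + view_name)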
However, the example code was a bit bare, and I ended up writing some helper functions to make passing dataframes, alongside other parameters, a little easier and more intuitive!
One of the issues I had with the toy example was that it used static names to pass dataframes, like my_data. That was clearly just an example, but I wanted a stronger guarantee against weird name collisions, so I used uuid to help randomize the names.
But I also wanted to be able to debug my new view names easily, so I mixed the random uuid suffixes with informative names too. It is not straightforward to programmatically capture the name of a variable as a string (you can go down a rabbit hole trying to figure this out, haha), so I settled on a simple function, prepare_arguments, which takes keyword arguments and uses the keyword names as the string names of the dataframes when creating random temp view names.
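As a quick sketch of the naming scheme (the suffix is random on every run, so the exact value here is just illustrative):

from uuid import uuid4

view_name = f"some_df_{str(uuid4())[:8]}"  # e.g. "some_df_3fa9c1b2"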
I also wanted the flexibility to mix and match the plain parameters you pass in with dataframes, without making a big deal about it.
Here is where I ended up :)
import json
from uuid import uuid4

from pyspark.sql import DataFrame, SparkSession

def prepare_arguments(**kwargs):
    """Create the dbutils.notebook.run payload and put dataframes into global_temp."""
    # Split the Spark dataframes out from the plain parameters.
    input_dataframes = {k: v for (k, v) in kwargs.items() if isinstance(v, DataFrame)}
    the_rest = {k: v for (k, v) in kwargs.items() if k not in input_dataframes}
    # Register each dataframe as a global temp view; pass the view names along as JSON.
    dataframes_dict = prepare_dataframe_references(**input_dataframes)
    return {**the_rest, "input_dataframes": json.dumps(dataframes_dict)}

def handle_output(raw_output):
    """Parse a notebook's JSON exit value, swapping any view names back into dataframes."""
    output = json.loads(raw_output)
    dataframes_dict = output.pop("output_dataframe_references", {})
    output_dataframes = dereference_dataframes(dataframes_dict)
    the_rest = {k: v for (k, v) in output.items() if k not in output_dataframes}
    return {**output_dataframes, **the_rest}

def dereference_dataframes(dataframes_dict):
    """Look up each global temp view name and return the underlying dataframes."""
    spark = SparkSession.builder.appName("project").getOrCreate()
    return {
        name: spark.table("global_temp." + view_name)
        for (name, view_name) in dataframes_dict.items()
    }

def prepare_dataframe_references(**kwargs):
    """Puts dataframes into the global_temp schema and returns the view names.

    Args:
        kwargs: key value pairs of names and dataframes, e.g.
            "some_df": <DataFrame>,
            "another_df": <DataFrame>,
        If any value is not a DataFrame, throws an exception.

    Returns:
        Dict mapping the same input names to view names, e.g.
        {
            "some_df": "some_df_fae8f78a",
            "another_df": "another_df_0a54d6fe",
        }
    """
    # Pair each dataframe with a readable, collision-resistant view name.
    input_dataframes = [
        {"name": k, "df": v, "view_name": f"{k}_{str(uuid4())[:8]}"}
        for (k, v) in kwargs.items()
        if isinstance(v, DataFrame)
    ]
    the_rest = {
        k: v
        for (k, v) in kwargs.items()
        if k not in [x["name"] for x in input_dataframes]
    }
    if the_rest:
        raise TypeError(f"Oops, got some non-dataframe arguments: {list(the_rest)}")
    # Register each dataframe under its generated view name.
    for x in input_dataframes:
        x["df"].createOrReplaceGlobalTempView(x["view_name"])
    return {x["name"]: x["view_name"] for x in input_dataframes}
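And here is a hypothetical end-to-end sketch of how the helpers fit together. The notebook path, the clean_data flag, and the output_df name are all made up for illustration; note that the plain parameters should already be strings, since notebook arguments arrive on the other side as widget values:

# Caller side: mix plain parameters and dataframes freely; prepare_arguments
# sorts out which is which (the notebook path is hypothetical).
raw_output = dbutils.notebook.run(
    "/path/to/callee_notebook",
    600,
    prepare_arguments(clean_data="true", some_df=some_df, another_df=another_df),
)
results = handle_output(raw_output)
results["output_df"].show()

# Callee side: the view names arrive as a JSON string widget, so dereferencing
# is one line. The callee can return dataframes the same way, by calling
# prepare_dataframe_references and exiting with json.dumps of a payload that
# includes the resulting dict under "output_dataframe_references".
inputs = dereference_dataframes(json.loads(dbutils.widgets.get("input_dataframes")))
some_df = inputs["some_df"]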