I was working through some odd window function behavior. While analyzing some feature drift data, I wanted the min, max, and mean drift values for each feature, partitioning on `compare_date`. A plain group-by would have covered that, but I also wanted the `baseline_date` corresponding to the largest drift score, so I took the approach below. It gave me some strange results.
```python
from pyspark.sql.window import Window
import pyspark.sql.functions as F

toy_df = spark.createDataFrame(
    [
        {'feature': 'feat1', 'category': 'cat1', 'Drift score': 0.0, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat1', 'category': 'cat1', 'Drift score': 0.0, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat1', 'category': 'cat1', 'Drift score': 0.0, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat2', 'category': 'cat1', 'Drift score': 0.16076398135644604, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat2', 'category': 'cat1', 'Drift score': 0.07818495131083669, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat2', 'category': 'cat1', 'Drift score': 0.07164427544566881, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat3', 'category': 'cat1', 'Drift score': 0.2018208744775895, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat3', 'category': 'cat1', 'Drift score': 0.06897468871439233, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat3', 'category': 'cat1', 'Drift score': 0.07111383432227428, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat5', 'category': 'cat1', 'Drift score': 0.20151850543660316, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat5', 'category': 'cat1', 'Drift score': 0.05584133483840621, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat5', 'category': 'cat1', 'Drift score': 0.056223672793567, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat6', 'category': 'cat1', 'Drift score': 0.10648175064912868, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat6', 'category': 'cat1', 'Drift score': 0.03398787644288803, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat6', 'category': 'cat1', 'Drift score': 0.027693531284292805, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat7', 'category': 'cat1', 'Drift score': 0.12696742943404185, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat7', 'category': 'cat1', 'Drift score': 0.07147622765870758, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat7', 'category': 'cat1', 'Drift score': 0.07478091185430771, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat8', 'category': 'cat2', 'Drift score': 0.11779958630386245, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat8', 'category': 'cat2', 'Drift score': 0.04240444683921199, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
    ]
)
toy_df.show()
```

```
+--------------------+-------------+--------+------------+-------+-----+
|         Drift score|baseline_date|category|compare_date|feature|group|
+--------------------+-------------+--------+------------+-------+-----+
|                 0.0|     20191231|    cat1|    20230131|  feat1| blah|
|                 0.0|     20220131|    cat1|    20230131|  feat1| blah|
|                 0.0|     20220731|    cat1|    20230131|  feat1| blah|
| 0.16076398135644604|     20191231|    cat1|    20230131|  feat2| blah|
| 0.07818495131083669|     20220131|    cat1|    20230131|  feat2| blah|
| 0.07164427544566881|     20220731|    cat1|    20230131|  feat2| blah|
|  0.2018208744775895|     20191231|    cat1|    20230131|  feat3| blah|
| 0.06897468871439233|     20220131|    cat1|    20230131|  feat3| blah|
| 0.07111383432227428|     20220731|    cat1|    20230131|  feat3| blah|
| 0.20151850543660316|     20191231|    cat1|    20230131|  feat5| blah|
| 0.05584133483840621|     20220131|    cat1|    20230131|  feat5| blah|
|   0.056223672793567|     20220731|    cat1|    20230131|  feat5| blah|
| 0.10648175064912868|     20191231|    cat1|    20230131|  feat6| blah|
| 0.03398787644288803|     20220131|    cat1|    20230131|  feat6| blah|
|0.027693531284292805|     20220731|    cat1|    20230131|  feat6| blah|
| 0.12696742943404185|     20191231|    cat1|    20230131|  feat7| blah|
| 0.07147622765870758|     20220131|    cat1|    20230131|  feat7| blah|
| 0.07478091185430771|     20220731|    cat1|    20230131|  feat7| blah|
| 0.11779958630386245|     20191231|    cat2|    20230131|  feat8| blah|
| 0.04240444683921199|     20220131|    cat2|    20230131|  feat8| blah|
+--------------------+-------------+--------+------------+-------+-----+
```

Applying the window function:

```python
w = Window.partitionBy("group", "feature", "compare_date").orderBy(F.col("Drift score").desc())

(
    toy_df
    .withColumn("mean_score", F.round(F.mean("Drift score").over(w), 4))
    .withColumn("max_score", F.round(F.max("Drift score").over(w), 4))
    .withColumn("min_score", F.round(F.min("Drift score").over(w), 4))
    .withColumn("baseline_date_max_score", F.first("baseline_date").over(w))
    .withColumn("row_num", F.row_number().over(w))
    .where(F.col("row_num") == 1)
    .drop("row_num")
    .select("category", "feature", "compare_date", "mean_score", "max_score", "min_score", "baseline_date_max_score")
    .show()
)
```

```
+--------+-------+------------+----------+---------+---------+-----------------------+
|category|feature|compare_date|mean_score|max_score|min_score|baseline_date_max_score|
+--------+-------+------------+----------+---------+---------+-----------------------+
|    cat1|  feat1|    20230131|       0.0|      0.0|      0.0|               20191231|
|    cat1|  feat2|    20230131|    0.1608|   0.1608|   0.1608|               20191231|
|    cat1|  feat3|    20230131|    0.2018|   0.2018|   0.2018|               20191231|
|    cat1|  feat5|    20230131|    0.2015|   0.2015|   0.2015|               20191231|
|    cat1|  feat6|    20230131|    0.1065|   0.1065|   0.1065|               20191231|
|    cat1|  feat7|    20230131|     0.127|    0.127|    0.127|               20191231|
|    cat2|  feat8|    20230131|    0.1178|   0.1178|   0.1178|               20191231|
+--------+-------+------------+----------+---------+---------+-----------------------+
```

I was confused:
why were the min, max, and mean all the same? My first thought was that my data might be corrupted and some of my partitions only had one row.
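The data was fine, though, so another hypothesis worth checking is the window frame itself. My understanding is that when a window spec has an `orderBy`, Spark defaults the frame to `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`, which turns `mean`, `max`, and `min` into *running* aggregates. A plain-Python sketch of that hypothesis (the scores here are made up, sorted descending like one partition of the toy data):

```python
# Simulate an ordered window's default frame:
# RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW.
scores = [0.1608, 0.0782, 0.0716]  # one partition, ordered by score desc

running = []
for i in range(1, len(scores) + 1):
    frame = scores[:i]  # rows from partition start up to the current row
    running.append((i, min(frame), max(frame), sum(frame) / len(frame)))

for row_num, mn, mx, mean in running:
    print(row_num, mn, mx, mean)
# For the row with row_number == 1, the frame contains only that row,
# so min, max, and mean all collapse to its own score.
```

If that is indeed the cause, the aggregate columns would need their own full-partition frame, e.g. `w.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)`, while `row_number` keeps the ordered spec.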
...