I was working through some odd PySpark window function behavior.

While analyzing some feature drift data, I wanted the min, max, and mean drift values per feature, partitioned on compare_date. I would normally have just done a groupBy, but I also wanted the baseline_date associated with the largest drift score, so I went with the approach below. I ended up with some strange results.

from pyspark.sql.window import Window
import pyspark.sql.functions as F

toy_df = spark.createDataFrame(
    [
        {'feature': 'feat1', 'category': 'cat1', 'Drift score': 0.0, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat1', 'category': 'cat1', 'Drift score': 0.0, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat1', 'category': 'cat1', 'Drift score': 0.0, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat2', 'category': 'cat1', 'Drift score': 0.16076398135644604, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat2', 'category': 'cat1', 'Drift score': 0.07818495131083669, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat2', 'category': 'cat1', 'Drift score': 0.07164427544566881, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat3', 'category': 'cat1', 'Drift score': 0.2018208744775895, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat3', 'category': 'cat1', 'Drift score': 0.06897468871439233, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat3', 'category': 'cat1', 'Drift score': 0.07111383432227428, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat5', 'category': 'cat1', 'Drift score': 0.20151850543660316, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat5', 'category': 'cat1', 'Drift score': 0.05584133483840621, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat5', 'category': 'cat1', 'Drift score': 0.056223672793567, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat6', 'category': 'cat1', 'Drift score': 0.10648175064912868, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat6', 'category': 'cat1', 'Drift score': 0.03398787644288803, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat6', 'category': 'cat1', 'Drift score': 0.027693531284292805, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat7', 'category': 'cat1', 'Drift score': 0.12696742943404185, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat7', 'category': 'cat1', 'Drift score': 0.07147622765870758, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
        {'feature': 'feat7', 'category': 'cat1', 'Drift score': 0.07478091185430771, 'group': 'blah', 'baseline_date': '20220731', 'compare_date': '20230131'},
        {'feature': 'feat8', 'category': 'cat2', 'Drift score': 0.11779958630386245, 'group': 'blah', 'baseline_date': '20191231', 'compare_date': '20230131'},
        {'feature': 'feat8', 'category': 'cat2', 'Drift score': 0.04240444683921199, 'group': 'blah', 'baseline_date': '20220131', 'compare_date': '20230131'},
    ]
)
toy_df.show()

+--------------------+-------------+--------+------------+-------+-----+
|         Drift score|baseline_date|category|compare_date|feature|group|
+--------------------+-------------+--------+------------+-------+-----+
|                 0.0|     20191231|    cat1|    20230131|  feat1| blah|
|                 0.0|     20220131|    cat1|    20230131|  feat1| blah|
|                 0.0|     20220731|    cat1|    20230131|  feat1| blah|
| 0.16076398135644604|     20191231|    cat1|    20230131|  feat2| blah|
| 0.07818495131083669|     20220131|    cat1|    20230131|  feat2| blah|
| 0.07164427544566881|     20220731|    cat1|    20230131|  feat2| blah|
|  0.2018208744775895|     20191231|    cat1|    20230131|  feat3| blah|
| 0.06897468871439233|     20220131|    cat1|    20230131|  feat3| blah|
| 0.07111383432227428|     20220731|    cat1|    20230131|  feat3| blah|
| 0.20151850543660316|     20191231|    cat1|    20230131|  feat5| blah|
| 0.05584133483840621|     20220131|    cat1|    20230131|  feat5| blah|
|   0.056223672793567|     20220731|    cat1|    20230131|  feat5| blah|
| 0.10648175064912868|     20191231|    cat1|    20230131|  feat6| blah|
| 0.03398787644288803|     20220131|    cat1|    20230131|  feat6| blah|
|0.027693531284292805|     20220731|    cat1|    20230131|  feat6| blah|
| 0.12696742943404185|     20191231|    cat1|    20230131|  feat7| blah|
| 0.07147622765870758|     20220131|    cat1|    20230131|  feat7| blah|
| 0.07478091185430771|     20220731|    cat1|    20230131|  feat7| blah|
| 0.11779958630386245|     20191231|    cat2|    20230131|  feat8| blah|
| 0.04240444683921199|     20220131|    cat2|    20230131|  feat8| blah|
+--------------------+-------------+--------+------------+-------+-----+
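As an aside, this is roughly the plain groupBy I would otherwise have reached for. It's only a sketch, but it shows why it wasn't enough on its own: it produces the min, max, and mean cleanly, yet gives no direct handle on which baseline_date produced the maximum score.

# Sketch: the plain groupBy alternative. It yields the per-group
# aggregates, but cannot by itself carry along the baseline_date
# that belongs to the maximum drift score.
agg_df = (
    toy_df
    .groupBy("group", "feature", "compare_date")
    .agg(
        F.round(F.mean("Drift score"), 4).alias("mean_score"),
        F.round(F.max("Drift score"), 4).alias("max_score"),
        F.round(F.min("Drift score"), 4).alias("min_score"),
    )
)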

Applying the window function here,

w = Window.partitionBy("group", "feature", "compare_date").orderBy(F.col("Drift score").desc())
(
    toy_df
    .withColumn("mean_score", F.round(F.mean("Drift score").over(w), 4))
    .withColumn("max_score", F.round(F.max("Drift score").over(w), 4))
    .withColumn("min_score", F.round(F.min("Drift score").over(w), 4))
    .withColumn("baseline_date_max_score", F.first("baseline_date").over(w))
    .withColumn("row_num", F.row_number().over(w))
    .where(F.col("row_num") == 1)
    .drop("row_num")
    .select("category", "feature", "compare_date", "mean_score", "max_score", "min_score", "baseline_date_max_score")
    .show()
)

+--------+-------+------------+----------+---------+---------+-----------------------+
|category|feature|compare_date|mean_score|max_score|min_score|baseline_date_max_score|
+--------+-------+------------+----------+---------+---------+-----------------------+
|    cat1|  feat1|    20230131|       0.0|      0.0|      0.0|               20191231|
|    cat1|  feat2|    20230131|    0.1608|   0.1608|   0.1608|               20191231|
|    cat1|  feat3|    20230131|    0.2018|   0.2018|   0.2018|               20191231|
|    cat1|  feat5|    20230131|    0.2015|   0.2015|   0.2015|               20191231|
|    cat1|  feat6|    20230131|    0.1065|   0.1065|   0.1065|               20191231|
|    cat1|  feat7|    20230131|     0.127|    0.127|    0.127|               20191231|
|    cat2|  feat8|    20230131|    0.1178|   0.1178|   0.1178|               20191231|
+--------+-------+------------+----------+---------+---------+-----------------------+

I was confused: why were the min, max, and mean all the same? Could my data be corrupted, with some partitions containing only a single row?

To debug, I took out the filter on row_num:

w = Window.partitionBy("group", "feature", "compare_date").orderBy(F.col("Drift score").desc())
(
    toy_df
    .withColumn("mean_score", F.round(F.mean("Drift score").over(w), 4))
    .withColumn("max_score", F.round(F.max("Drift score").over(w), 4))
    .withColumn("min_score", F.round(F.min("Drift score").over(w), 4))
    .withColumn("baseline_date_max_score", F.first("baseline_date").over(w))
    .withColumn("row_num", F.row_number().over(w))

    .select("category", "feature", "compare_date", "Drift score" , "mean_score", "max_score", "min_score",  )
    .show()
)

+--------+-------+------------+--------------------+----------+---------+---------+
|category|feature|compare_date|         Drift score|mean_score|max_score|min_score|
+--------+-------+------------+--------------------+----------+---------+---------+
|    cat1|  feat1|    20230131|                 0.0|       0.0|      0.0|      0.0|
|    cat1|  feat1|    20230131|                 0.0|       0.0|      0.0|      0.0|
|    cat1|  feat1|    20230131|                 0.0|       0.0|      0.0|      0.0|
|    cat1|  feat2|    20230131| 0.16076398135644604|    0.1608|   0.1608|   0.1608|
|    cat1|  feat2|    20230131| 0.07818495131083669|    0.1195|   0.1608|   0.0782|
|    cat1|  feat2|    20230131| 0.07164427544566881|    0.1035|   0.1608|   0.0716|
|    cat1|  feat3|    20230131|  0.2018208744775895|    0.2018|   0.2018|   0.2018|
|    cat1|  feat3|    20230131| 0.07111383432227428|    0.1365|   0.2018|   0.0711|
|    cat1|  feat3|    20230131| 0.06897468871439233|     0.114|   0.2018|    0.069|
|    cat1|  feat5|    20230131| 0.20151850543660316|    0.2015|   0.2015|   0.2015|
|    cat1|  feat5|    20230131|   0.056223672793567|    0.1289|   0.2015|   0.0562|
|    cat1|  feat5|    20230131| 0.05584133483840621|    0.1045|   0.2015|   0.0558|
|    cat1|  feat6|    20230131| 0.10648175064912868|    0.1065|   0.1065|   0.1065|
|    cat1|  feat6|    20230131| 0.03398787644288803|    0.0702|   0.1065|    0.034|
|    cat1|  feat6|    20230131|0.027693531284292805|    0.0561|   0.1065|   0.0277|
|    cat1|  feat7|    20230131| 0.12696742943404185|     0.127|    0.127|    0.127|
|    cat1|  feat7|    20230131| 0.07478091185430771|    0.1009|    0.127|   0.0748|
|    cat1|  feat7|    20230131| 0.07147622765870758|    0.0911|    0.127|   0.0715|
|    cat2|  feat8|    20230131| 0.11779958630386245|    0.1178|   0.1178|   0.1178|
|    cat2|  feat8|    20230131| 0.04240444683921199|    0.0801|   0.1178|   0.0424|
+--------+-------+------------+--------------------+----------+---------+---------+

Now this looked even stranger: the min, max, and mean differed from row to row within the same partition, and each row's values seemed to depend on its position in the sort order.
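In hindsight, a quick way to debug this is to expose exactly which rows each row's window covers, for example by collecting the frame contents. A sketch using the same ordered window w:

# Sketch: show the frame each row actually sees. With the ordered
# window, the collected list grows as we move down each partition,
# suggesting each row's frame is cumulative rather than covering
# the whole partition.
(
    toy_df
    .withColumn("frame_contents", F.collect_list("Drift score").over(w))
    .select("feature", "Drift score", "frame_contents")
    .show(truncate=False)
)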

I vaguely remembered reading somewhere that the orderBy, which I still needed for the baseline_date column, was changing the behavior of the min, max, and mean columns, so I took it out of the base window.

I ended up with:

w = Window.partitionBy("group", "feature", "compare_date")
(
    toy_df
    .withColumn("mean_score", F.round(F.mean("Drift score").over(w), 4))
    .withColumn("max_score", F.round(F.max("Drift score").over(w), 4))
    .withColumn("min_score", F.round(F.min("Drift score").over(w), 4))
    .withColumn("baseline_date_max_score", F.first("baseline_date").over(w.orderBy(F.col("Drift score").desc())))
    .withColumn("row_num", F.row_number().over(w.orderBy("Drift score")))
    .select("category", "feature", "compare_date", "Drift score" , "mean_score", "max_score", "min_score", "baseline_date_max_score" )

    .show()
)

+--------+-------+------------+--------------------+----------+---------+---------+-----------------------+
|category|feature|compare_date|         Drift score|mean_score|max_score|min_score|baseline_date_max_score|
+--------+-------+------------+--------------------+----------+---------+---------+-----------------------+
|    cat1|  feat1|    20230131|                 0.0|       0.0|      0.0|      0.0|               20191231|
|    cat1|  feat1|    20230131|                 0.0|       0.0|      0.0|      0.0|               20191231|
|    cat1|  feat1|    20230131|                 0.0|       0.0|      0.0|      0.0|               20191231|
|    cat1|  feat2|    20230131| 0.16076398135644604|    0.1035|   0.1608|   0.0716|               20191231|
|    cat1|  feat2|    20230131| 0.07818495131083669|    0.1035|   0.1608|   0.0716|               20191231|
|    cat1|  feat2|    20230131| 0.07164427544566881|    0.1035|   0.1608|   0.0716|               20191231|
|    cat1|  feat3|    20230131|  0.2018208744775895|     0.114|   0.2018|    0.069|               20191231|
|    cat1|  feat3|    20230131| 0.07111383432227428|     0.114|   0.2018|    0.069|               20191231|
|    cat1|  feat3|    20230131| 0.06897468871439233|     0.114|   0.2018|    0.069|               20191231|
|    cat1|  feat5|    20230131| 0.20151850543660316|    0.1045|   0.2015|   0.0558|               20191231|
|    cat1|  feat5|    20230131|   0.056223672793567|    0.1045|   0.2015|   0.0558|               20191231|
|    cat1|  feat5|    20230131| 0.05584133483840621|    0.1045|   0.2015|   0.0558|               20191231|
|    cat1|  feat6|    20230131| 0.10648175064912868|    0.0561|   0.1065|   0.0277|               20191231|
|    cat1|  feat6|    20230131| 0.03398787644288803|    0.0561|   0.1065|   0.0277|               20191231|
|    cat1|  feat6|    20230131|0.027693531284292805|    0.0561|   0.1065|   0.0277|               20191231|
|    cat1|  feat7|    20230131| 0.12696742943404185|    0.0911|    0.127|   0.0715|               20191231|
|    cat1|  feat7|    20230131| 0.07478091185430771|    0.0911|    0.127|   0.0715|               20191231|
|    cat1|  feat7|    20230131| 0.07147622765870758|    0.0911|    0.127|   0.0715|               20191231|
|    cat2|  feat8|    20230131| 0.11779958630386245|    0.0801|   0.1178|   0.0424|               20191231|
|    cat2|  feat8|    20230131| 0.04240444683921199|    0.0801|   0.1178|   0.0424|               20191231|
+--------+-------+------------+--------------------+----------+---------+---------+-----------------------+

Finally

I could now filter back down to one row per partition:

w = Window.partitionBy("group", "feature", "compare_date")
(
    toy_df
    .withColumn("mean_score", F.round(F.mean("Drift score").over(w), 4))
    .withColumn("max_score", F.round(F.max("Drift score").over(w), 4))
    .withColumn("min_score", F.round(F.min("Drift score").over(w), 4))
    .withColumn("baseline_date_max_score", F.first("baseline_date").over(w.orderBy(F.col("Drift score").desc())))
    .withColumn("row_num", F.row_number().over(w.orderBy("Drift score")))
    .where(F.col("row_num") == 1)
    .drop("row_num")
    .select("category", "feature", "compare_date", "mean_score", "max_score", "min_score", "baseline_date_max_score" )

    .show()
)

+--------+-------+------------+----------+---------+---------+-----------------------+
|category|feature|compare_date|mean_score|max_score|min_score|baseline_date_max_score|
+--------+-------+------------+----------+---------+---------+-----------------------+
|    cat1|  feat1|    20230131|       0.0|      0.0|      0.0|               20191231|
|    cat1|  feat2|    20230131|    0.1035|   0.1608|   0.0716|               20191231|
|    cat1|  feat3|    20230131|     0.114|   0.2018|    0.069|               20191231|
|    cat1|  feat5|    20230131|    0.1045|   0.2015|   0.0558|               20191231|
|    cat1|  feat6|    20230131|    0.0561|   0.1065|   0.0277|               20191231|
|    cat1|  feat7|    20230131|    0.0911|    0.127|   0.0715|               20191231|
|    cat2|  feat8|    20230131|    0.0801|   0.1178|   0.0424|               20191231|
+--------+-------+------------+----------+---------+---------+-----------------------+

What explains this odd behavior?

Facepalm!

I did not notice this at the time, but a colleague pointed out 🤦‍♂️ that although min, max, and mean don't require an orderBy, when one is provided they become cumulative (running) aggregates: with an orderBy and no explicit frame, Spark defaults the window frame to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, whereas without an orderBy the frame spans the entire partition. Spot-checking the outputs above, he was right! Nice 😀
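Given that explanation, there is also a cleaner fix than splitting the ordering across windows: keep one ordered window and pin its frame to the whole partition with an explicit rowsBetween. A sketch of what that could look like (w_ordered and w_full are illustrative names):

# Sketch: one ordered window, plus a variant with an explicit frame
# spanning the whole partition, so min/max/mean are no longer
# running aggregates even though the window is ordered.
w_ordered = Window.partitionBy("group", "feature", "compare_date").orderBy(F.col("Drift score").desc())
w_full = w_ordered.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

(
    toy_df
    .withColumn("mean_score", F.round(F.mean("Drift score").over(w_full), 4))
    .withColumn("max_score", F.round(F.max("Drift score").over(w_full), 4))
    .withColumn("min_score", F.round(F.min("Drift score").over(w_full), 4))
    .withColumn("baseline_date_max_score", F.first("baseline_date").over(w_ordered))
    .withColumn("row_num", F.row_number().over(w_ordered))
    .where(F.col("row_num") == 1)
    .drop("row_num")
    .select("category", "feature", "compare_date", "mean_score", "max_score", "min_score", "baseline_date_max_score")
    .show()
)

Note that row_number (and first, as used here) stays on the plain ordered window, since Spark does not allow an explicit frame for ranking functions; only the min, max, and mean need the full-partition frame.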