PySpark Find Maximum Row per Group in DataFrame

In PySpark, you can find/select the maximum (max) row per group by partitioning the data with Window.partitionBy() and running the row_number() function over the window partition. Let’s see how with a DataFrame example.

1. Prepare Data & DataFrame

First, let’s create a PySpark DataFrame with three columns: Name, Department, and Salary. The Department column contains different department values and is used for grouping.


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

This yields the output below.


+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+
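
For contrast, a plain aggregation with groupBy() returns only the maximum salary value per department, not the full row; the window approach in the next section is what returns the entire row. Below is a minimal sketch of that aggregation, assuming the df defined above.


from pyspark.sql.functions import max as max_

# groupBy().agg(max()) gives the maximum salary per department,
# but drops the other columns (e.g. Name) of the matching row.
df.groupBy("Department").agg(max_("Salary").alias("max_salary")).show()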

2. PySpark Find Maximum Row per Group in DataFrame

We can select/find the maximum row per group using either PySpark SQL or the DataFrame API. In this section, we use the DataFrame API with row_number() over a window defined by partitionBy() and orderBy().

This example finds the row with the highest salary in each department.


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
  .filter(col("row") == 1).drop("row") \
  .show()

The above example performs the following steps.

  • First, partition the DataFrame on the department column, which groups all rows with the same department together.
  • Apply orderBy() on the salary column in descending order.
  • Add a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition.
  • Using the PySpark filter(), keep only the rows where row == 1, which returns the row with the maximum salary in each group.
  • Finally, since the row column is no longer needed, drop it.

This yields the output below.


+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|Michael|     Sales|  4600|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
+-------+----------+------+
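
Note that row_number() returns exactly one row per group even when multiple employees tie for the top salary in a department. If you want to keep all tied rows, one option is to swap row_number() for rank(), which assigns the same rank to ties. Below is a minimal sketch of that variation, reusing the windowDept specification and df from above; with this sample data the result is identical because no department has a tie at the top salary.


from pyspark.sql.functions import rank

# rank() assigns the same rank to rows with equal salary, so every row
# that shares a department's highest salary is kept.
df.withColumn("rnk", rank().over(windowDept)) \
  .filter(col("rnk") == 1).drop("rnk") \
  .show()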

3. PySpark SQL expression to Find Maximum Row per Group

You can also get the maximum row for each group using a SQL expression. If you have a SQL background, this will feel familiar. In order to use SQL, first create a temporary view using createOrReplaceTempView().


df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

This yields the same output as above. Note that createOrReplaceTempView() creates or replaces a local temporary view backed by the DataFrame df. The lifetime of this view is tied to your current SparkSession; to drop the view, use spark.catalog.dropTempView("tempViewName").
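
For example, the EMP view registered above can be dropped like this when it is no longer needed.


# Drop the temporary view once it is no longer needed.
spark.catalog.dropTempView("EMP")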

4. Complete Example

Following is the complete example.


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
  .filter(col("row") == 1).drop("row") \
  .show()

df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

5. Conclusion

In summary, you can find the maximum (max) row for each group by partitioning the data by the group column using Window.partitionBy(), sorting the data within each partition, adding a row_number() column to the sorted data, and finally filtering to keep the first row of each group.

Happy Learning !!
