PySpark Find Maximum Row per Group in DataFrame

  • Post category: PySpark
  • Post last modified: March 27, 2024

In PySpark, you can find the maximum (max) row per group by partitioning the data with Window.partitionBy(), ordering each partition by the value column in descending order, running the row_number() function over that window, and keeping the rows numbered 1. Let's see how with a DataFrame example in PySpark (Python).

Related: PySpark Window Function Explained with Examples

1. Prepare Data & DataFrame

First, let's create a PySpark DataFrame with three columns: Name, Department, and Salary. The Department column holds the departments we will group by.


# Imports
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

# Create data
data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

# Create DataFrame
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

This yields the following output.


# Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+

From the above DataFrame, we want to select the row with the maximum salary in each department: Michael in Sales (4600), Jen in Finance (3900), and Jeff in Marketing (3000).

2. PySpark Find Maximum Row per Group in DataFrame

You can find the maximum row per group using either PySpark SQL or the DataFrame API. This section uses the DataFrame API with the window function row_number() together with Window.partitionBy() and orderBy().

This example selects the row with the highest salary in each department. It assumes a DataFrame df with a Department column that defines the groups and a Salary column whose maximum you want; adjust the column names and data types to match your own DataFrame.


# Imports
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Define a window specification that partitions by the Department column
# and orders each partition by Salary in descending order
windowDept = Window.partitionBy("Department").orderBy(col("Salary").desc())

# Number the rows within each partition: 1 = highest salary
ranked = df.withColumn("row", row_number().over(windowDept))

# Keep only the rows numbered 1 (the maximum salary within each group)
# and drop the helper column
ranked.filter(col("row") == 1).drop("row") \
  .show()

The example above performs the following steps:

  • Partition the DataFrame on the Department column with Window.partitionBy(), so all rows of the same department fall into one window partition.
  • Order each partition by the Salary column in descending order with orderBy().
  • Add a new column, row, by running row_number() over the window. row_number() returns a sequential number starting at 1 within each partition, so the highest-paid row of each department gets 1.
  • Use filter() to keep only the rows where row == 1; these are the maximum-salary rows of each group.
  • Finally, drop the row column, since it is only a helper.
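To make these steps concrete, here is a plain-Python sketch (no Spark involved) of what partitionBy(), orderBy(desc), row_number(), and the row == 1 filter compute on the same sample data. It only illustrates the logic; Spark distributes this work very differently, and the helper name max_row_per_group is hypothetical.

```python
from itertools import groupby
from operator import itemgetter

data = [("James","Sales",3000),("Michael","Sales",4600),
        ("Robert","Sales",4100),("Maria","Finance",3000),
        ("Raman","Finance",3000),("Scott","Finance",3300),
        ("Jen","Finance",3900),("Jeff","Marketing",3000),
        ("Kumar","Marketing",2000)]

def max_row_per_group(rows, group_idx=1, value_idx=2):
    """Hypothetical helper: keep the first row of each group after sorting by value, descending."""
    # "partitionBy": sort so rows of the same group sit next to each other
    by_group = sorted(rows, key=itemgetter(group_idx))
    result = []
    for _, group_rows in groupby(by_group, key=itemgetter(group_idx)):
        # "orderBy(desc)": highest value first within the partition
        ordered = sorted(group_rows, key=itemgetter(value_idx), reverse=True)
        # "row_number()": number the rows 1, 2, 3, ... within the partition
        numbered = enumerate(ordered, start=1)
        # "filter(row == 1)" then "drop(row)": keep row 1, discard the number
        result.extend(row for n, row in numbered if n == 1)
    return result

for row in max_row_per_group(data):
    print(row)
```

Python's sort is stable, so tied values keep their input order in this sketch; Spark gives no such guarantee for row_number() on ties.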

This yields the following output.


# Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|Michael|     Sales|  4600|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
+-------+----------+------+

3. PySpark SQL expression to Find Maximum Row per Group

You can also get the maximum record for each group with a SQL expression, which will feel familiar if you have an SQL background. To use SQL, first create a temporary view with createOrReplaceTempView(). This creates a local temporary view from DataFrame df, replacing any existing view with the same name. The view lives only as long as the current SparkSession; to drop it earlier, use spark.catalog.dropTempView("tempViewName").


df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

This yields the same output as above.

This Spark SQL query performs the following operations:

  1. Subquery with Window Function:
    • The inner subquery (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn FROM EMP) tmp selects all columns (*) from the EMP table/view and adds a new column rn. The row_number() window function assigns a sequential number to each row within each department partition based on the salary column in descending order.
  2. Outer Query:
    • The outer query select Name, Department, Salary from (subquery) tmp where rn = 1 selects columns Name, Department, and Salary from the subquery result tmp where the rn (row number) is equal to 1.
    • The rn = 1 condition filters the result set to only retain rows where the row number within each department partition is 1. This effectively retrieves the row with the highest salary within each department.
  3. spark.sql().show():
    • Finally, the .show() method displays the resulting DataFrame after executing the SQL query.

4. Complete Example

Following is the complete example.


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("Department").orderBy(col("Salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
  .filter(col("row") == 1).drop("row") \
  .show()

df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

5. Conclusion

In summary, you can find the maximum (max) row for each group by partitioning the data with Window.partitionBy(), sorting each partition in descending order, numbering the sorted rows with row_number(), and keeping only the first row of each group.

Happy Learning !!

Naveen Nelamali

Naveen Nelamali (NNK) is a Data Engineer with 20+ years of experience in transforming data into actionable insights. Over the years, he has honed his expertise in designing, implementing, and maintaining data pipelines with frameworks like Apache Spark, PySpark, Pandas, R, Hive and Machine Learning. Naveen's journey in the field of data engineering has been one of continuous learning, innovation, and a strong commitment to data integrity. In this blog, he shares his experiences with data as he comes across them. Follow Naveen @ LinkedIn and Medium