PySpark Find Maximum Row per Group in DataFrame

To find the maximum row per group in PySpark, you can use a window function. First, partition the DataFrame by the grouping column(s). Then apply a window function, such as max(), to the column of interest; this computes the maximum value within each partition. Finally, filter the DataFrame to retain only the rows whose value matches the maximum for their group.

You can also achieve this using the row_number() function with a window partition.

Related: PySpark Window Function Explained with Examples

1. Prepare Data & DataFrame

Let’s prepare the data and create a DataFrame.


# Import 
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

# Create data
data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

# Create DataFrame
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

Yields below output.


# Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+

I will use this DataFrame to select the highest salary per department; the resulting DataFrame will contain only the top-salary row from each department.

2. PySpark Find Maximum Row per Group in DataFrame

To find the maximum row per group using PySpark’s DataFrame API, first create a window partitioned by the grouping column(s). Second, apply the row_number() window function to assign a sequential number to each row within each partition, ordered by the column(s) of interest in descending order. Finally, filter the DataFrame to retain only the rows where the row number equals 1; these are the maximum rows per group.

The following example returns the employee with the highest salary in each department. Note that this code assumes you have a DataFrame df with columns named "Department" and "Salary", and that you want the maximum salary within each "Department" group.


# Imports
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Define a window specification to partition by the department column 
# and order by the salary column
windowDept = Window.partitionBy("Department").orderBy(col("Salary").desc())

# Use the window function to rank rows within each partition
rank = df.withColumn("row", row_number().over(windowDept))

# Filter the DataFrame to keep only the rows with rank 1 
# (i.e., the maximum salary within each group)
rank.filter(col("row") == 1).drop("row") \
  .show()

Yields below output.


# Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|Michael|     Sales|  4600|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
+-------+----------+------+
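
If two employees tie for the highest salary in a department, row_number() keeps only one of them (chosen arbitrarily among the ties). As a sketch of a variation not shown above, you can swap in dense_rank(), which assigns the same rank to tied rows, so every top-salary row survives the filter:


# Keep all rows that tie for the maximum salary per department
from pyspark.sql.functions import dense_rank

df.withColumn("rank", dense_rank().over(windowDept)) \
  .filter(col("rank") == 1).drop("rank") \
  .show()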

3. Using max() Aggregate Function

To obtain the highest salary per department from the DataFrame above, you can combine a window specification with the max() aggregate function. Unlike the row_number() approach, this keeps every row that ties for the group maximum, because each row whose salary equals the maximum passes the filter.


from pyspark.sql import Window
from pyspark.sql.functions import col, max  # note: this shadows Python's built-in max()

# Define a window specification partitioned by department
windowSpec = Window.partitionBy("Department")

# Add a column for the maximum salary within each department
df_max_salary = df.withColumn("max_salary", max(col("Salary")).over(windowSpec))

# Filter the DataFrame to retain only rows where the salary matches the maximum salary within its department
result = df_max_salary.filter(col("Salary") == col("max_salary")).drop("max_salary")

# Show the result
result.show()
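
This yields the same three rows as the row_number() example above. For comparison, here is a minimal sketch (not part of the original example) that produces the same result without window functions, by aggregating per department and joining back:


# Compute the per-department maximum salary, then join back to recover the full rows
from pyspark.sql.functions import max as max_

max_per_dept = df.groupBy("Department").agg(max_("Salary").alias("Salary"))
df.join(max_per_dept, on=["Department", "Salary"]).show()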

4. PySpark SQL Expression to Find Maximum Row per Group

You can also get the maximum record for each group using a SQL expression. If you have a SQL background, this approach will feel familiar. To use SQL, first create a temporary view with createOrReplaceTempView(). This creates or replaces a local temporary view backed by the DataFrame df; its lifetime is tied to your current SparkSession. To drop the view, use spark.catalog.dropTempView("tempViewName").


df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

This yields the same output as above.

This Spark SQL query performs the following operations:

  1. Subquery with Window Function:
    • The inner subquery (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn FROM EMP) tmp selects all columns (*) from the EMP table/view and adds a new column rn. The row_number() window function assigns a sequential number to each row within each department partition based on the salary column in descending order.
  2. Outer Query:
    • The outer query select Name, Department, Salary from (subquery) tmp where rn = 1 selects columns Name, Department, and Salary from the subquery result tmp where the rn (row number) is equal to 1.
    • The rn = 1 condition filters the result set to only retain rows where the row number within each department partition is 1. This effectively retrieves the row with the highest salary within each department.
  3. spark.sql().show():
    • Finally, the .show() method displays the resulting DataFrame after executing the SQL query.
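
The same query can also be written with a common table expression, which some find easier to read; this is a minor stylistic variation on the query above, not a different technique:


# Same query expressed with a WITH clause (CTE)
spark.sql("""
    WITH ranked AS (
        SELECT *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS rn
        FROM EMP
    )
    SELECT Name, Department, Salary FROM ranked WHERE rn = 1
""").show()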

5. Complete Example

Following is the complete example.


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("Department").orderBy(col("Salary").desc())
df.withColumn("row", row_number().over(windowDept)) \
  .filter(col("row") == 1).drop("row") \
  .show()

df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

6. Conclusion

In conclusion, you can find the maximum row per group in PySpark using row_number() or max() over a window built with Window.partitionBy() and orderBy(), or with an equivalent SQL expression. Window functions let you partition the data by specific criteria and order rows within each partition, so keeping only the top row per group becomes a simple filter.

Happy Learning !!
