In PySpark, you can find the row with the maximum value in each group by defining a window with Window.partitionBy(), running the row_number() function over that window, and keeping the first row of each partition. Let's walk through a DataFrame example in PySpark (Python).
1. Prepare Data & DataFrame
First, let's create a PySpark DataFrame with three columns: Name, Department, and Salary. The Department column holds the department names used for grouping.
# Import
from pyspark.sql import SparkSession
# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
# Create data
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
# Create DataFrame
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
This yields the following output.
# Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+
From this DataFrame, we want to select the row with the highest salary in each department: Michael (Sales), Jen (Finance), and Jeff (Marketing).
2. PySpark Find Maximum Row per Group in DataFrame
We can find the maximum row per group using either PySpark SQL or the DataFrame API. This section uses the DataFrame API with the window function row_number() together with partitionBy() and orderBy().
This example finds the row with the highest salary in each department group. Note that the code assumes you have a DataFrame df with columns named "department" and "salary", and that you want the maximum value within each group of the "department" column. You may need to adjust column names and data types to match your own DataFrame.
# Imports
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
# Define a window specification to partition by the department column
# and order by the salary column
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
# Use the window function to rank rows within each partition
rank = df.withColumn("row",row_number().over(windowDept))
# Filter the DataFrame to keep only the rows with rank 1
# (i.e., the maximum salary within each group)
rank.filter(col("row") == 1).drop("row") \
.show()
The above example performs the following steps.
- First, partition the DataFrame on the department column, which puts all rows of the same department into one group.
- Apply orderBy() on the salary column in descending order.
- Add a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition. If several rows tie for the highest salary, they still get distinct numbers, so only one of them is returned.
- Using the PySpark filter(), select only the rows where row == 1, which are the rows with the maximum salary of each group.
- Finally, if the row column is not needed, drop it.
This yields the following output.
# Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|Michael|     Sales|  4600|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
+-------+----------+------+
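To make the window steps in this section concrete, here is a minimal plain-Python sketch (not Spark code) that mimics them on the same sample data: group by department, sort each group by salary in descending order, number the rows from 1, and keep row 1. The helper name max_row_per_group is hypothetical and exists only for illustration. Spark performs the equivalent work in a distributed way.

```python
from itertools import groupby
from operator import itemgetter

# Same sample rows as the DataFrame: (Name, Department, Salary)
data = [("James", "Sales", 3000), ("Michael", "Sales", 4600),
        ("Robert", "Sales", 4100), ("Maria", "Finance", 3000),
        ("Raman", "Finance", 3000), ("Scott", "Finance", 3300),
        ("Jen", "Finance", 3900), ("Jeff", "Marketing", 3000),
        ("Kumar", "Marketing", 2000)]


def max_row_per_group(rows):
    """Hypothetical helper: emulate partitionBy + orderBy(desc) + row_number() == 1."""
    result = []
    # "partitionBy": sort by department so groupby sees each group contiguously
    by_department = sorted(rows, key=itemgetter(1))
    for _, group in groupby(by_department, key=itemgetter(1)):
        # "orderBy(salary desc)": highest salary first within the group
        ordered = sorted(group, key=itemgetter(2), reverse=True)
        # "row_number()": sequential numbers starting from 1
        numbered = enumerate(ordered, start=1)
        # "filter(row == 1)" and "drop(row)": keep the first row, discard the number
        result.extend(row for number, row in numbered if number == 1)
    return result


top = max_row_per_group(data)
for row in top:
    print(row)
```

Because the groups are sorted by department name in this sketch, the rows come back in alphabetical department order, which may differ from the order Spark displays.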
3. PySpark SQL expression to Find Maximum Row per Group
You can also get the maximum record for each group using a SQL expression, which will feel familiar if you have a SQL background. To use SQL, you first need to create a temporary view using createOrReplaceTempView(). This creates or replaces a local temporary view backed by the DataFrame df, and the lifetime of the view is tied to your current SparkSession. If you want to drop the view, use spark.catalog.dropTempView("tempViewName").
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
This yields the same output as above.
This Spark SQL query performs the following operations:
- Subquery with window function: the inner subquery (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn FROM EMP) tmp selects all columns (*) from the EMP view and adds a new column rn. The row_number() window function assigns a sequential number to each row within each department partition, ordered by the salary column in descending order.
- Outer query: the outer query select Name, Department, Salary from (subquery) tmp where rn = 1 selects the Name, Department, and Salary columns from the subquery result tmp. The rn = 1 condition keeps only the rows whose row number within each department partition is 1, which retrieves the row with the highest salary in each department.
- spark.sql().show(): finally, the .show() method displays the resulting DataFrame after executing the SQL query.
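If you want to experiment with the shape of this query without a running Spark session, the sketch below runs the same subquery and outer query against an in-memory SQLite database using Python's built-in sqlite3 module. This is only a stand-in for Spark SQL, and it assumes your Python build bundles SQLite 3.25 or later (the first release with window functions). The ORDER BY Department clause is an addition made here so the result order is deterministic.

```python
import sqlite3

# Same sample rows as the DataFrame: (Name, Department, Salary)
data = [("James", "Sales", 3000), ("Michael", "Sales", 4600),
        ("Robert", "Sales", 4100), ("Maria", "Finance", 3000),
        ("Raman", "Finance", 3000), ("Scott", "Finance", 3300),
        ("Jen", "Finance", 3900), ("Jeff", "Marketing", 3000),
        ("Kumar", "Marketing", 2000)]

# In-memory table standing in for the EMP temporary view
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE EMP (Name TEXT, Department TEXT, Salary INTEGER)")
conn.executemany("INSERT INTO EMP VALUES (?, ?, ?)", data)

# Inner query numbers rows per department by salary (descending);
# outer query keeps only row number 1 of each partition.
query = """
SELECT Name, Department, Salary
FROM (SELECT *,
             row_number() OVER (PARTITION BY Department ORDER BY Salary DESC) AS rn
      FROM EMP) tmp
WHERE rn = 1
ORDER BY Department
"""
top_rows = conn.execute(query).fetchall()
conn.close()

for row in top_rows:
    print(row)
```

The subquery and the rn = 1 filter are unchanged from the Spark SQL version; only the engine and the way the table is loaded differ.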
4. Complete Example
The following is the complete example.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
.filter(col("row") == 1).drop("row") \
.show()
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
5. Conclusion
In summary, you can find the maximum (max) row for each group by partitioning the data by group using the window partitionBy(), sorting the rows within each partition in descending order, adding row_number() to the sorted data, and finally filtering to keep the first record of each partition.
Happy Learning !!