To find the maximum row per group in PySpark, you can use a window function. First, partition the DataFrame by the grouping column(s). Then, apply a window function such as max() to the desired column(s); this computes the maximum value within each partition. Finally, filter the DataFrame to retain only the rows where the value matches the maximum value of its group.
You can also achieve this using the row_number() function over a window partition.
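For example, here is a minimal sketch of the max()-over-window idea, assuming an existing DataFrame df with a grouping column "group_col" and a numeric column "value_col" (both placeholder names):
# Minimal sketch; "group_col" and "value_col" are placeholder column names
# and df is assumed to already exist
from pyspark.sql.window import Window
from pyspark.sql.functions import col, max

w = Window.partitionBy("group_col")
result = df.withColumn("grp_max", max(col("value_col")).over(w)) \
    .filter(col("value_col") == col("grp_max")) \
    .drop("grp_max")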
Related: PySpark Window Function Explained with Examples
1. Prepare Data & DataFrame
Let’s prepare the data and create a DataFrame.
# Import
from pyspark.sql import SparkSession,Row
# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
# Create data
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
# Create DataFrame
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
Yields below output.
# Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+
I will use this DataFrame to select the highest salary per department; the resulting DataFrame will contain only one row per department, the row with the highest salary.
2. PySpark Find Maximum Row per Group in DataFrame
To calculate the maximum row per group using PySpark’s DataFrame API, first create a window partitioned by the grouping column(s). Second, apply the row_number() window function to assign a unique sequential number to each row within each partition, ordered by the column(s) of interest. Finally, filter the DataFrame to retain only the rows where the row number equals 1, which indicates the maximum row per group.
The following example calculates the highest salary of each department group. Note that this code assumes you have a DataFrame df with columns named "department" and "salary", and that you want to find the maximum salary within each department group.
# Imports
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
# Define a window specification to partition by the department column
# and order by the salary column
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
# Use the window function to rank rows within each partition
rank = df.withColumn("row",row_number().over(windowDept))
# Filter the DataFrame to keep only the rows with rank 1
# (i.e., the maximum salary within each group)
rank.filter(col("row") == 1).drop("row") \
.show()
Yields below output.
# Output:
+-------+----------+------+
| Name|Department|Salary|
+-------+----------+------+
|Michael| Sales| 4600|
| Jen| Finance| 3900|
| Jeff| Marketing| 3000|
+-------+----------+------+
3. Using max() Aggregate Function
To obtain the highest salary per department from the DataFrame created above, you can use a window function along with the max() aggregation function.
from pyspark.sql import Window
from pyspark.sql.functions import col, max
# Define a window specification partitioned by department
windowSpec = Window.partitionBy("Department")
# Add a column for the maximum salary within each department
df_max_salary = df.withColumn("max_salary", max(col("Salary")).over(windowSpec))
# Filter the DataFrame to retain only rows where the salary matches the maximum salary within its department
result = df_max_salary.filter(col("Salary") == col("max_salary")).drop("max_salary")
# Show the result
result.show()
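This yields the same output as the previous example. Note that, unlike row_number(), this approach keeps every row that ties for the maximum salary within its department, so a group with duplicate maximum salaries returns more than one row.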
4. PySpark SQL Expression to Find Maximum Row per Group
You can also get the maximum record for each group using a SQL expression. If you have a SQL background, this will feel familiar. To use SQL, first create a temporary view with createOrReplaceTempView(). This creates or replaces a local temporary view from the DataFrame df, and the view's lifetime is tied to your current SparkSession. If you want to drop the view, use spark.catalog.dropTempView("tempViewName").
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
This yields the same output as above.
This Spark SQL query performs the following operations:
- Subquery with Window Function: The inner subquery (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn FROM EMP) tmp selects all columns (*) from the EMP table/view and adds a new column rn. The row_number() window function assigns a sequential number to each row within each department partition, ordered by the salary column in descending order.
- Outer Query: The outer query select Name, Department, Salary from (subquery) tmp where rn = 1 selects the Name, Department, and Salary columns from the subquery result tmp where rn (the row number) equals 1. The rn = 1 condition keeps only the row with the highest salary within each department partition.
- spark.sql().show(): Finally, the show() method displays the resulting DataFrame after executing the SQL query.
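As an alternative sketch (not part of the original query), the same result can be produced with a plain MAX() aggregation and a join instead of a window function. Like the max() window approach in section 3, this keeps all rows that tie for the department maximum.
# Alternative sketch: aggregate + join instead of a window function
spark.sql("""
    SELECT e.Name, e.Department, e.Salary
    FROM EMP e
    JOIN (SELECT Department, MAX(Salary) AS max_salary
          FROM EMP
          GROUP BY Department) m
      ON e.Department = m.Department
     AND e.Salary = m.max_salary
""").show()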
5. Complete Example
Following is the complete example.
from pyspark.sql import SparkSession,Row
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
.filter(col("row") == 1).drop("row") \
.show()
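# Using the max() window function (from section 3); unlike row_number(),
# this keeps all rows that tie for the maximum salary within a department
from pyspark.sql.functions import max
windowSpec = Window.partitionBy("Department")
df.withColumn("max_salary", max(col("Salary")).over(windowSpec)) \
    .filter(col("Salary") == col("max_salary")) \
    .drop("max_salary") \
    .show()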
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
6. Conclusion
In conclusion, identifying the maximum row per group in PySpark using methods such as row_number(), Window.partitionBy(), and orderBy() offers a streamlined and efficient approach to data analysis. By leveraging window functions and DataFrame APIs, analysts can easily partition data based on specific criteria and order it within each partition to determine the maximum row.
Happy Learning !!
Related Articles
- PySpark Row using on DataFrame and RDD
- Pyspark Select Distinct Rows
- PySpark Get Number of Rows and Columns
- PySpark Select Top N Rows From Each Group
- PySpark Select First Row of Each Group?
- PySpark How to Filter Rows with NULL Values
- PySpark Distinct to Drop Duplicate Rows
- PySpark Drop Rows with NULL or None Values
- PySpark – explode nested array into rows