In PySpark, you can find or select the maximum (max) row per group by partitioning the data with Window.partitionBy() and running the row_number() window function over each partition. Let's see how with a DataFrame example.
1. Prepare Data & DataFrame
First, let's create a PySpark DataFrame with three columns: Name, Department, and Salary. The Department column holds the department each employee belongs to, and it is the column we group by.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
This yields the output below.
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+
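Since createDataFrame() infers column types from the Python data, you can optionally confirm the schema before grouping. This quick check is not required for the examples that follow; it simply prints each column with its inferred type.

# Optional sanity check: Name and Department are inferred as string, Salary as long
df.printSchema()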
2. PySpark Find Maximum Row per Group in DataFrame
We can select or find the maximum row per group using either PySpark SQL or the DataFrame API. In this section, we use the DataFrame API with the window function row_number() together with partitionBy() and orderBy(). This example finds the highest salary in each department group.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Partition by department and sort each partition by salary, highest first
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())

# Number the rows per partition, keep the first (highest salary), drop the helper column
df.withColumn("row",row_number().over(windowDept)) \
  .filter(col("row") == 1).drop("row") \
  .show()
The example above performs the following steps:
- First, partition the DataFrame on the department column, which groups all rows with the same department together.
- Apply orderBy() on the salary column in descending order, so the highest salary comes first in each partition.
- Add a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition (see the tie-handling note after the output below).
- Using the PySpark filter(), select only row == 1, which returns the row with the maximum salary in each group.
- Finally, since the row column is no longer needed, drop it.
This yields the output below.
+-------+----------+------+
| Name|Department|Salary|
+-------+----------+------+
|Michael| Sales| 4600|
| Jen| Finance| 3900|
| Jeff| Marketing| 3000|
+-------+----------+------+
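Note that row_number() always keeps exactly one row per group and breaks ties arbitrarily, so if two employees shared a department's top salary, only one of them would survive the filter. If you want to keep all tied rows, one option is to swap in rank(), which assigns the same number to tied rows. A minimal sketch, reusing the windowDept spec defined above:

from pyspark.sql.functions import rank

# rank() assigns tied salaries the same number, so every row that shares the
# department maximum passes the rnk == 1 filter (row_number() would keep just one)
df.withColumn("rnk", rank().over(windowDept)) \
  .filter(col("rnk") == 1).drop("rnk") \
  .show()

With the sample data above this prints the same three rows, since no department has a tie at the top salary.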
3. PySpark SQL expression to Find Maximum Row per Group
You can also get the maximum record per group using a SQL expression. If you have a SQL background, this approach will feel familiar. To use SQL, first create a temporary view with createOrReplaceTempView().
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
This yields the same output as above. Note that createOrReplaceTempView() creates (or replaces) a local temporary view from the DataFrame df. The lifetime of this view is tied to your current SparkSession; to drop the view, use spark.catalog.dropTempView("EMP").
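If you prefer plain aggregation over window functions, the same per-group maximum can also be computed with a GROUP BY subquery joined back to the table. This is a sketch against the EMP view created above; unlike the row_number() query, it keeps every row that ties for a department's maximum:

# Aggregate-and-join alternative: compute each department's max salary,
# then join back to EMP to recover the matching rows (ties included)
spark.sql("""
    SELECT e.Name, e.Department, e.Salary
    FROM EMP e
    JOIN (SELECT Department, MAX(Salary) AS Salary
          FROM EMP GROUP BY Department) m
      ON e.Department = m.Department AND e.Salary = m.Salary
""").show()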
4. Complete Example
Following is the complete example.
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
.filter(col("row") == 1).drop("row") \
.show()
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
5. Conclusion
In summary, you can find the maximum (max) row per group by partitioning the data with Window.partitionBy(), sorting the rows within each partition by the value of interest, adding a row_number() column over the sorted partitions, and finally filtering to keep only the first row of each group.
Happy Learning !!
Related Articles
- PySpark Row using on DataFrame and RDD
- PySpark Select Distinct Rows
- PySpark Get Number of Rows and Columns
- PySpark Select Top N Rows From Each Group
- PySpark Select First Row of Each Group?
- PySpark How to Filter Rows with NULL Values
- PySpark Distinct to Drop Duplicate Rows
- PySpark Drop Rows with NULL or None Values
- PySpark – explode nested array into rows