In PySpark, finding or selecting the top N rows for each group can be done by partitioning the data with a window: use the Window.partitionBy() function, run the row_number() function over the grouped partition, and finally filter the rows to get the top N rows. Let's see this with a DataFrame example.
Below is a quick snippet that gives you the top 2 rows for each group.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
.filter(col("row") <= 2) \
.show()
Alternatively, you can also get the same result using PySpark SQL.
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn <= 2").show()
In the above snippet, I am getting the top 2 salaries from each department; just change the filter condition to get the top 5, 10, or N records. Let's see what's happening at each step with an actual example.
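Since only the filter value changes between "top 2" and "top N", you could also wrap the whole pattern in a small reusable function. The sketch below is just an illustration; the helper name top_n_per_group and its parameters are my own, not part of the original snippet, and it assumes the same DataFrame used throughout this article.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Hypothetical helper (not from the original article): returns the top `n` rows
# per group, ordered by `order_col` in descending order.
def top_n_per_group(df, group_col, order_col, n):
    w = Window.partitionBy(group_col).orderBy(col(order_col).desc())
    return (df.withColumn("row", row_number().over(w))
              .filter(col("row") <= n)
              .drop("row"))

# Example usage: top 5 salaries per department
# top_n_per_group(df, "Department", "Salary", 5).show()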
Select Top N Rows From Each Group in PySpark Example
Let's create a PySpark DataFrame with 3 columns: Name, Department, and Salary. The Department column contains the different departments we will group on.
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
Yields below output.
+-------+----------+------+
| Name|Department|Salary|
+-------+----------+------+
| James| Sales| 3000|
|Michael| Sales| 4600|
| Robert| Sales| 4100|
| Maria| Finance| 3000|
| Raman| Finance| 3000|
| Scott| Finance| 3300|
| Jen| Finance| 3900|
| Jeff| Marketing| 3000|
| Kumar| Marketing| 2000|
+-------+----------+------+
Window.partitionBy()
Partition the DataFrame on the department column using Window.partitionBy(), sort each partition by the salary column in descending order, and use the row_number() function to add a sequence number to each group, naming the new column row.
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df2=df.withColumn("row",row_number().over(windowDept))
df2.show()
Yields below output.
+-------+----------+------+---+
| Name|Department|Salary|row|
+-------+----------+------+---+
|Michael| Sales| 4600| 1|
| Robert| Sales| 4100| 2|
| James| Sales| 3000| 3|
| Jen| Finance| 3900| 1|
| Scott| Finance| 3300| 2|
| Maria| Finance| 3000| 3|
| Raman| Finance| 3000| 4|
| Jeff| Marketing| 3000| 1|
| Kumar| Marketing| 2000| 2|
+-------+----------+------+---+
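As an aside on this output: Maria and Raman both earn 3000 in Finance, and row_number() breaks that tie arbitrarily (they get rows 3 and 4 here). If you would rather have tied salaries share the same position, PySpark's rank() function can be used in place of row_number(). The snippet below is only a sketch of that variation, not part of the original example; note that "top 2" may then return more than 2 rows when a tie falls on the boundary.
from pyspark.sql.functions import rank

# Tied salaries share the same rank, unlike row_number()
df.withColumn("rnk", rank().over(windowDept)) \
  .filter(col("rnk") <= 2) \
  .show()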
Filter the Rows With a Row Number Less Than or Equal to 2
Now filter the DataFrame to return the top N rows. Here it returns the first 2 records for each group; replace 2 with whatever value you want.
df3=df2.filter(col("row") <= 2)
df3.show()
Yields below output.
+-------+----------+------+---+
| Name|Department|Salary|row|
+-------+----------+------+---+
|Michael| Sales| 4600| 1|
| Robert| Sales| 4100| 2|
| Jen| Finance| 3900| 1|
| Scott| Finance| 3300| 2|
| Jeff| Marketing| 3000| 1|
| Kumar| Marketing| 2000| 2|
+-------+----------+------+---+
Drop Row Number Column
Finally, remove the row column that holds the row number. If you need the row number for any further processing, you can keep this column (see the short sketch after the output below).
df3.drop("row").show()
Yields below output.
+-------+----------+------+
| Name|Department|Salary|
+-------+----------+------+
|Michael| Sales| 4600|
| Robert| Sales| 4100|
| Jen| Finance| 3900|
| Scott| Finance| 3300|
| Jeff| Marketing| 3000|
| Kumar| Marketing| 2000|
+-------+----------+------+
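If you kept the row column instead of dropping it, it can be reused directly in later steps. For example, a quick sketch (my own illustration, not from the original example) that keeps only the single highest-paid employee per department from df2:
# Reuse the row number from df2 to pick only the top earner in each department
df2.filter(col("row") == 1).drop("row").show()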
Complete Example
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
data = [("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000)]
df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept)) \
  .filter(col("row") <= 2) \
  .drop("row").show()
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn <= 2").show()
Conclusion
In summary, you can select/find the top N rows for each group in a PySpark DataFrame by partitioning the data by group using Window.partitionBy(), sorting the data within each partition, adding a row_number() to the sorted data, and finally filtering to get the top N records.
Happy Learning !!
Related Articles
- PySpark Retrieve DataType & Column Names of DataFrame
- PySpark Add New Column with Row Number
- PySpark Parse JSON from String Column | TEXT File
- PySpark Select Nested struct Columns
- PySpark Check Column Exists in DataFrame
- PySpark Convert DataFrame Columns to MapType (Dict)
- PySpark Convert Dictionary/Map to Multiple Columns
- Pyspark Select Distinct Rows
- PySpark Select First Row of Each Group?