PySpark Select Top N Rows From Each Group

In PySpark, you can find the top N rows of each group with a window function. Partition the data with Window.partitionBy() and sort it with orderBy(), run the row_number() function over each partition, and finally filter the rows to keep the top N. Let’s see this with a DataFrame example.

Below is a quick snippet that gives you the top 2 rows of each group.


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Number the rows within each department, highest salary first
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
# Keep only the first 2 rows of every department
df.withColumn("row",row_number().over(windowDept)) \
  .filter(col("row") <= 2) \
  .show()

Alternatively, you can get the same result using PySpark SQL:


df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn <= 2").show()

In the snippets above, I get the top 2 salaries from each department. To get the top 5, 10, or N records, just change the filter condition. Let’s see what happens at each step with a complete example.

Select Top N Rows From Each Group in PySpark Example

Let’s create a PySpark DataFrame with three columns: Name, Department, and Salary. The Department column holds each employee’s department, which is what we group by.


from pyspark.sql import SparkSession

# Create (or reuse) a SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

This yields the output below.


+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+

Window.partitionBy()

Partition the DataFrame on the department column using Window.partitionBy() and sort each group by the salary column in descending order. Then use the row_number() function to add a sequential number to each row within its group, in a new column named row.
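To make the mechanics concrete, here is a plain-Python sketch (not Spark code) of what row_number() over this window computes on the sample data: group the rows, sort each group on its own, then number the rows starting from 1 in every group. The function name with_row_number and its group_idx/order_idx parameters are illustrative assumptions, not part of PySpark.

```python
from collections import defaultdict

data = [("James", "Sales", 3000), ("Michael", "Sales", 4600),
        ("Robert", "Sales", 4100), ("Maria", "Finance", 3000),
        ("Raman", "Finance", 3000), ("Scott", "Finance", 3300),
        ("Jen", "Finance", 3900), ("Jeff", "Marketing", 3000),
        ("Kumar", "Marketing", 2000)]


def with_row_number(rows, group_idx=1, order_idx=2):
    """Hypothetical plain-Python illustration of
    row_number() OVER (PARTITION BY rows[group_idx] ORDER BY rows[order_idx] DESC)."""
    # Like Window.partitionBy(): collect the rows of each group
    # (dicts keep groups in first-seen order).
    partitions = defaultdict(list)
    for row in rows:
        partitions[row[group_idx]].append(row)

    numbered = []
    for part in partitions.values():
        # Like orderBy(col(...).desc()): sort each partition independently.
        # Python's sort is stable, so tied rows keep their input order here;
        # Spark does not promise any particular order for ties.
        ordered = sorted(part, key=lambda r: r[order_idx], reverse=True)
        # Like row_number(): 1, 2, 3, ... restarting in every partition.
        for n, row in enumerate(ordered, start=1):
            numbered.append(row + (n,))
    return numbered


for row in with_row_number(data):
    print(row)
```

Keeping only the rows whose last element is at most N reproduces the filter step that follows.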


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df2=df.withColumn("row",row_number().over(windowDept))
df2.show()

This yields the output below.


+-------+----------+------+---+
|   Name|Department|Salary|row|
+-------+----------+------+---+
|Michael|     Sales|  4600|  1|
| Robert|     Sales|  4100|  2|
|  James|     Sales|  3000|  3|
|    Jen|   Finance|  3900|  1|
|  Scott|   Finance|  3300|  2|
|  Maria|   Finance|  3000|  3|
|  Raman|   Finance|  3000|  4|
|   Jeff| Marketing|  3000|  1|
|  Kumar| Marketing|  2000|  2|
+-------+----------+------+---+
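Note that Maria and Raman both earn 3000, so row_number() has to place one of them before the other. Spark does not guarantee which tied row gets 3 and which gets 4 unless you add a tie-breaking column to orderBy(). If tied rows should share a number, pyspark.sql.functions also provides rank() and dense_rank(), which you can use in place of row_number() in the same window expression. The sketch below is plain Python, not Spark, and only illustrates how the three functions number rows that are already sorted in descending order. The function name ranking_columns and the second salary list are hypothetical.

```python
def ranking_columns(values):
    """Hypothetical illustration of row_number, rank and dense_rank
    over values sorted in descending order."""
    ordered = sorted(values, reverse=True)
    row_number, rank, dense_rank = [], [], []
    for i, v in enumerate(ordered):
        # row_number: always a distinct 1-based position, even for ties.
        row_number.append(i + 1)
        if i > 0 and v == ordered[i - 1]:
            # Tie with the previous row: rank and dense_rank repeat.
            rank.append(rank[-1])
            dense_rank.append(dense_rank[-1])
        else:
            # New value: rank jumps to the 1-based position (gaps after ties),
            # dense_rank moves to the next consecutive number (no gaps).
            rank.append(i + 1)
            dense_rank.append(dense_rank[-1] + 1 if dense_rank else 1)
    return ordered, row_number, rank, dense_rank


# Finance salaries from the sample data
ordered, rn, rk, drk = ranking_columns([3000, 3000, 3300, 3900])
print(ordered)  # [3900, 3300, 3000, 3000]
print(rn)       # [1, 2, 3, 4]
print(rk)       # [1, 2, 3, 3]
print(drk)      # [1, 2, 3, 3]

# Hypothetical salaries with a tie at the top, to show the gap in rank
_, rn2, rk2, drk2 = ranking_columns([4000, 5000, 5000])
print(rn2, rk2, drk2)  # [1, 2, 3] [1, 1, 3] [1, 1, 2]
```

With rank() or dense_rank(), a filter such as <= 3 keeps both Maria and Raman, so a group can return more than N rows; with row_number() it always returns at most N.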

Filter the Rows With a Row Number Less Than or Equal to 2

Now filter the DataFrame to return the top N rows. Here it returns the first 2 records of each group; replace 2 with the N you want.


df3=df2.filter(col("row") <= 2)
df3.show()

This yields the output below.


+-------+----------+------+---+
|   Name|Department|Salary|row|
+-------+----------+------+---+
|Michael|     Sales|  4600|  1|
| Robert|     Sales|  4100|  2|
|    Jen|   Finance|  3900|  1|
|  Scott|   Finance|  3300|  2|
|   Jeff| Marketing|  3000|  1|
|  Kumar| Marketing|  2000|  2|
+-------+----------+------+---+

Drop Row Number Column

Finally, drop the row column that holds the row number. If you need the row number for further processing, keep this column instead.


df3.drop("row").show()

This yields the output below.


+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|    Jen|   Finance|  3900|
|  Scott|   Finance|  3300|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+

Complete Example


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
# Wrap the chain in parentheses so it can span several lines
(df.withColumn("row",row_number().over(windowDept))
   .filter(col("row") <= 2)
   .drop("row")
   .show())

df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn <= 2").show()
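The SQL version uses standard window-function syntax, so as an optional cross-check outside Spark you can run the same query text against Python’s built-in sqlite3 module. This is a sketch under assumptions: it uses SQLite, not Spark SQL, and requires an SQLite library of version 3.25 or later (the first with window functions). Only the query text comes from the example above; the table setup is illustrative.

```python
import sqlite3

data = [("James", "Sales", 3000), ("Michael", "Sales", 4600),
        ("Robert", "Sales", 4100), ("Maria", "Finance", 3000),
        ("Raman", "Finance", 3000), ("Scott", "Finance", 3300),
        ("Jen", "Finance", 3900), ("Jeff", "Marketing", 3000),
        ("Kumar", "Marketing", 2000)]

# In-memory SQLite table standing in for the EMP temp view
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE EMP (Name TEXT, Department TEXT, Salary INTEGER)")
conn.executemany("INSERT INTO EMP VALUES (?, ?, ?)", data)

# Same query text as the spark.sql() call above
query = (
    "select Name, Department, Salary from "
    " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn "
    " FROM EMP) tmp where rn <= 2"
)
rows = conn.execute(query).fetchall()
conn.close()

# Row order of the result is not guaranteed, so sort before printing
for row in sorted(rows):
    print(row)
```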

Conclusion

In summary, you can select the top N rows of each group in a PySpark DataFrame by partitioning the data by group with Window.partitionBy(), sorting the rows within each partition, adding a row_number() over the sorted data, and finally filtering to keep the top N records.

Happy Learning !!

References