
In PySpark, you can find/select the top N rows from each group by partitioning the data by group using the Window.partitionBy() function, running the row_number() function over each partition, and finally filtering the rows to keep the top N. Let's see this with a DataFrame example.

Below is a quick snippet that gives you the top 2 rows for each group.


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Partition by department and sort each partition by salary in descending order
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())

# Add a row number within each partition and keep only the top 2 rows per group
df.withColumn("row",row_number().over(windowDept)) \
  .filter(col("row") <= 2) \
  .show()

Alternatively, you can get the same result using PySpark SQL.


df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn <= 2").show()

In the above snippet, I am getting the top 2 salaries from each department; just change the filter condition to get the top 5, 10, or N records (a reusable sketch is shown below). Let's see what's happening at each step with the actual example.
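If you want to parameterize N, a minimal sketch is to wrap the same window logic in a small helper function. Note that top_n_per_group below is a hypothetical name used for illustration, not part of the PySpark API.


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Hypothetical helper: returns the top `n` rows per `group_col`, ordered by `order_col` descending
def top_n_per_group(df, group_col, order_col, n):
    w = Window.partitionBy(group_col).orderBy(col(order_col).desc())
    return df.withColumn("row", row_number().over(w)) \
             .filter(col("row") <= n) \
             .drop("row")

# Example usage: top 3 salaries per department
top_n_per_group(df, "department", "salary", 3).show()
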

Select Top N Rows From Each Group in PySpark Example

Let's create a PySpark DataFrame with 3 columns: Name, Department, and Salary. The Department column contains the different departments used for grouping.


data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

Yields the below output.


+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+

Window.partitionBy()

Partition the DataFrame on the department column using Window.partitionBy(), sort each partition by the salary column in descending order, and use the row_number() function to add a sequence number to each group as a new column named row.


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df2=df.withColumn("row",row_number().over(windowDept))
df2.show()

Yields the below output.


+-------+----------+------+---+
|   Name|Department|Salary|row|
+-------+----------+------+---+
|Michael|     Sales|  4600|  1|
| Robert|     Sales|  4100|  2|
|  James|     Sales|  3000|  3|
|    Jen|   Finance|  3900|  1|
|  Scott|   Finance|  3300|  2|
|  Maria|   Finance|  3000|  3|
|  Raman|   Finance|  3000|  4|
|   Jeff| Marketing|  3000|  1|
|  Kumar| Marketing|  2000|  2|
+-------+----------+------+---+
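
Note that Maria and Raman have the same salary (3000) in the Finance group, so the order row_number() assigns between them is not guaranteed. If you need a deterministic result when values tie, one option (a sketch, not part of the original steps) is to add a secondary sort key to the window ordering:


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Sketch: break salary ties deterministically by also ordering on Name
windowDeptTie = Window.partitionBy("department") \
    .orderBy(col("salary").desc(), col("Name").asc())
df.withColumn("row", row_number().over(windowDeptTie)).show()
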

Filter Rows with Row Number Less Than or Equal to 2

Now filter the DataFrame to return the top N rows. Here it returns the first 2 records for each group; change the value 2 to the number of records you want.


df3=df2.filter(col("row") <= 2)
df3.show()

Yields the below output.


+-------+----------+------+---+
|   Name|Department|Salary|row|
+-------+----------+------+---+
|Michael|     Sales|  4600|  1|
| Robert|     Sales|  4100|  2|
|    Jen|   Finance|  3900|  1|
|  Scott|   Finance|  3300|  2|
|   Jeff| Marketing|  3000|  1|
|  Kumar| Marketing|  2000|  2|
+-------+----------+------+---+

Drop Row Number Column

Finally, remove the row column that holds the row number. If you need the row number for further processing, you can keep this column instead.


df3.drop("row").show()

Yields the below output.


+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|    Jen|   Finance|  3900|
|  Scott|   Finance|  3300|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+

Complete Example of PySpark Select Top N Rows From Each Group


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["Name","Department","Salary"])
df.show()

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
windowDept = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(windowDept))
  .filter(col("row") <= 2)
  .drop("row").show()

df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn <= 2").show()

Conclusion

In summary, you can select/find the top N rows for each group in a PySpark DataFrame by partitioning the data by group using Window.partitionBy(), sorting the data within each partition, adding row_number() to the sorted data, and finally filtering to get the top N records.

Happy Learning !!
