PySpark Select First Row of Each Group?

In PySpark, you can select the first row of each group using the window function row_number() along with the Window.partitionBy() method. First, partition the DataFrame by the desired grouping column(s) using partitionBy(), then order the rows within each partition by a specified column. Apply the row_number() function to generate a sequential row number within each partition. Finally, filter the DataFrame to retain only the rows where the row number equals 1, which are the first rows within each group.
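
In outline, the pattern looks like the sketch below, where df is your DataFrame and group_col and order_col are placeholder column names.


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Number the rows within each group, then keep only row 1
w = Window.partitionBy("group_col").orderBy(col("order_col"))
first_rows = df.withColumn("rn", row_number().over(w)) \
               .filter(col("rn") == 1).drop("rn")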

1. Prepare Data & DataFrame

Before we start, let’s create a PySpark DataFrame with three columns: employee_name, department, and salary. The department column contains the values we will group by.


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

# Sample data with employee_name, department, and salary
data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data, ["employee_name","department","salary"])
df.show()

This yields the table below.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        Raman|   Finance|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+

2. PySpark Select First Row From Every Group

We can select the first row of each group using either PySpark SQL or the DataFrame API. In this section, we use the DataFrame API with the window function row_number() and partitionBy().


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

# Partition by department and order each partition by salary (ascending)
w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row", row_number().over(w2)) \
  .filter(col("row") == 1).drop("row") \
  .show()

The above example performs the following steps.

  • First, partition the DataFrame on the department column, which groups all rows with the same department together.
  • Apply orderBy() on the salary column to order the rows within each partition.
  • Add a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition.
  • Use the PySpark filter() to select rows where row == 1, which returns the first row of each group.
  • Finally, since the row column is no longer needed, drop it.

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|        Maria|   Finance|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+
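
Note that Maria and Raman both earn 3000 in Finance, so which of them comes out as the "first" row is an arbitrary tie-break that can vary between runs. If you need a deterministic result, add a tiebreaker column to orderBy(); the sketch below uses employee_name as the tiebreaker.


# Break ties deterministically: order by salary, then employee_name (a sketch)
w_tie = Window.partitionBy("department").orderBy(col("salary"), col("employee_name"))
df.withColumn("row", row_number().over(w_tie)) \
  .filter(col("row") == 1).drop("row") \
  .show()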

Below is the equivalent PySpark SQL expression to achieve the same result.


df.createOrReplaceTempView("EMP")
spark.sql("""SELECT employee_name, department, salary
    FROM (SELECT *, row_number() OVER (PARTITION BY department ORDER BY salary) AS rn
          FROM EMP) tmp
    WHERE rn = 1""").show()

3. Retrieve the Employee Who Earns the Highest Salary

To retrieve the employee with the highest salary in each department, order by the salary column in descending order and take the first row.


# Order each department partition by salary in descending order
w3 = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row", row_number().over(w3)) \
  .filter(col("row") == 1).drop("row") \
  .show()

This outputs the following.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|      Michael|     Sales|  4600|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
+-------------+----------+------+

4. Select the Highest, Lowest, Average, and Total salary for each department group

Here, we will retrieve the highest, average, total, and lowest salary for each group. The snippet below applies the aggregate functions avg(), sum(), min(), and max() over an unordered window w4 partitioned by department, and reuses row_number() over w3 (defined above) just to keep a single row per department.


from pyspark.sql.functions import col, row_number, avg, sum, min, max

# Unordered window: aggregates are computed over the entire department partition
w4 = Window.partitionBy("department")
df.withColumn("row", row_number().over(w3)) \
  .withColumn("avg", avg(col("salary")).over(w4)) \
  .withColumn("sum", sum(col("salary")).over(w4)) \
  .withColumn("min", min(col("salary")).over(w4)) \
  .withColumn("max", max(col("salary")).over(w4)) \
  .where(col("row") == 1).select("department","avg","sum","min","max") \
  .show()

Outputs the following aggregated values for each group.


+----------+------+-----+----+----+
|department|   avg|  sum| min| max|
+----------+------+-----+----+----+
|     Sales|3900.0|11700|3000|4600|
|   Finance|3300.0|13200|3000|3900|
| Marketing|2500.0| 5000|2000|3000|
+----------+------+-----+----+----+
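
Since this result keeps only the aggregated columns, the same numbers can also be produced with a plain groupBy() and agg(), avoiding the window entirely. Below is an equivalent sketch.


# Equivalent aggregation with groupBy() instead of window functions (a sketch)
df.groupBy("department") \
  .agg(avg("salary").alias("avg"), sum("salary").alias("sum"),
       min("salary").alias("min"), max("salary").alias("max")) \
  .show()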

5. Complete Example of Select First Row of Each Group

Below is a complete example.


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)
    ]

df = spark.createDataFrame(data, ["employee_name","department","salary"])
df.show()

# Select First Row of Group
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row",row_number().over(w2)) \
  .filter(col("row") == 1).drop("row") \
  .show()

# Get the highest salary of each group
w3 = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(w3)) \
  .filter(col("row") == 1).drop("row") \
  .show()

# Get max, min, avg, sum of each group
from pyspark.sql.functions import col, row_number,avg,sum,min,max
w4 = Window.partitionBy("department")
df.withColumn("row",row_number().over(w3)) \
  .withColumn("avg", avg(col("salary")).over(w4)) \
  .withColumn("sum", sum(col("salary")).over(w4)) \
  .withColumn("min", min(col("salary")).over(w4)) \
  .withColumn("max", max(col("salary")).over(w4)) \
  .where(col("row")==1).select("department","avg","sum","min","max") \
  .show()

6. Conclusion

In this article, you learned how to retrieve the first row of each group in a PySpark DataFrame using window functions, and how to get the max, min, average, and total of each group, with examples.

Happy Learning !!
