
In PySpark, you can select/find the first row of each group within a DataFrame by partitioning the data with the Window partitionBy() function and running the row_number() function over the window partition. Let's see how with an example.


1. Prepare Data & DataFrame

Before we start, let's create a PySpark DataFrame with 3 columns: employee_name, department, and salary. The department column contains the different departments used for grouping.


from pyspark.sql import SparkSession,Row
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["employee_name","department","salary"])
df.show()

This yields the below output.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        Raman|   Finance|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+

2. PySpark Select First Row From every Group

We can select the first row of each group using either PySpark SQL or the DataFrame API. In this section, we will see how to do it with the DataFrame API using the window function row_number() together with partitionBy().


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row",row_number().over(w2)) \
  .filter(col("row") == 1).drop("row") \
  .show()

The above example performs the below steps.

  • First, partition the DataFrame on the department column, which groups all rows of the same department together.
  • Apply orderBy() on the salary column.
  • Add a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition.
  • Using the PySpark filter(), select only rows where row == 1, which returns the first row of each group.
  • Finally, since the row column is no longer needed, drop it.

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|        Maria|   Finance|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+
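
To make the steps clearer, the below sketch (not part of the original example) shows the intermediate row column that row_number() assigns before the filter and drop are applied.

df.withColumn("row", row_number().over(w2)).show()
# Each department gets its own sequence 1, 2, 3, ... ordered by salary.
# Note: when two rows tie on salary (e.g., Maria and Raman, both 3000),
# row_number() breaks the tie arbitrarily; add another column to orderBy()
# if a deterministic first row is required.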

Here is the PySpark SQL expression to achieve the same result.


df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary) as rn " +
     " FROM EMP) tmp where rn = 1").show()

3. Retrieve Employee who earns the highest salary

To retrieve the employee with the highest salary in each department, apply orderBy() on the salary column in descending order and retrieve the first row of each group.


w3 = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(w3)) \
  .filter(col("row") == 1).drop("row") \
  .show()

This yields the following output.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|      Michael|     Sales|  4600|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
+-------------+----------+------+

4. Select the Highest, Lowest, Average, and Total salary for each department group

Here, we will retrieve the highest, average, total, and lowest salary for each department. The below snippet uses partitionBy() and row_number() along with the aggregation functions avg(), sum(), min(), and max(). Note that row_number() is still computed over the ordered window w3 from the previous section, while the aggregates are computed over the unordered window w4.


from pyspark.sql.functions import col, row_number,avg,sum,min,max
w4 = Window.partitionBy("department")
df.withColumn("row",row_number().over(w3)) \
  .withColumn("avg", avg(col("salary")).over(w4)) \
  .withColumn("sum", sum(col("salary")).over(w4)) \
  .withColumn("min", min(col("salary")).over(w4)) \
  .withColumn("max", max(col("salary")).over(w4)) \
  .where(col("row")==1).select("department","avg","sum","min","max") \
  .show()

Outputs the following aggregated values for each group.


+----------+------+-----+----+----+
|department|   avg|  sum| min| max|
+----------+------+-----+----+----+
|     Sales|3900.0|11700|3000|4600|
|   Finance|3300.0|13200|3000|3900|
| Marketing|2500.0| 5000|2000|3000|
+----------+------+-----+----+----+
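
When only the aggregated values are needed and the per-row window context is not, a plain groupBy() gives the same result; below is a minimal alternative sketch (not part of the original example).

from pyspark.sql.functions import avg, sum, min, max
df.groupBy("department") \
  .agg(avg("salary").alias("avg"), sum("salary").alias("sum"),
       min("salary").alias("min"), max("salary").alias("max")) \
  .show()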

5. Complete Example of PySpark Select First Row of Each Group

Below is the complete example.


from pyspark.sql import SparkSession,Row
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)
    ]

df = spark.createDataFrame(data,["employee_name","department","salary"])
df.show()

# Select First Row of Group
from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row",row_number().over(w2)) \
  .filter(col("row") == 1).drop("row") \
  .show()

#Get highest salary of each group  
w3 = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(w3)) \
  .filter(col("row") == 1).drop("row") \
  .show()

#Get max, min, avg, sum of each group
from pyspark.sql.functions import col, row_number,avg,sum,min,max
w4 = Window.partitionBy("department")
df.withColumn("row",row_number().over(w3)) \
  .withColumn("avg", avg(col("salary")).over(w4)) \
  .withColumn("sum", sum(col("salary")).over(w4)) \
  .withColumn("min", min(col("salary")).over(w4)) \
  .withColumn("max", max(col("salary")).over(w4)) \
  .where(col("row")==1).select("department","avg","sum","min","max") \
  .show()

6. Conclusion

In this article, you have learned how to retrieve the first row of each group in a PySpark DataFrame by using window functions, and also how to get the max, min, average, and total of each group with examples.

Happy Learning !!


This Post Has One Comment

  1. Hanan

    While window functions work, I found that finding the latest timestamp and then using a `left-semi` join with the original data works several orders of magnitude faster.
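
A rough sketch of the approach described in this comment, adapted to the article's sample data (find each department's maximum salary, then keep only the matching rows with a left-semi join); this is an illustration, not the commenter's exact code.

from pyspark.sql.functions import max
max_per_dept = df.groupBy("department").agg(max("salary").alias("salary"))
# left_semi keeps only the rows of df whose (department, salary) match the per-department maximum
df.join(max_per_dept, on=["department","salary"], how="left_semi").show()
# Unlike row_number(), this returns every employee tied for the maximum salary.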
