PySpark Select First Row of Each Group?

In PySpark, you can select the first row of each group in a DataFrame by defining a window with Window.partitionBy() and running the row_number() window function over each window partition. Let's see how with an example.

1. Prepare Data & DataFrame

Before we start, let's create a PySpark DataFrame with three columns: employee_name, department, and salary. The department column holds the values we group by.


from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)]

df = spark.createDataFrame(data,["employee_name","department","salary"])
df.show()

This outputs the following table:


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        Raman|   Finance|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+

2. PySpark Select First Row From Every Group

You can select the first row of each group with either PySpark SQL or the DataFrame API. In this section, we use the DataFrame API with the window function row_number() and partitionBy().


from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number
w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row",row_number().over(w2)) \
  .filter(col("row") == 1).drop("row") \
  .show()

The example above performs the following steps:

  • First, partition the DataFrame on the department column, which puts all rows with the same department into one group.
  • Order the rows within each partition by salary with orderBy() (ascending by default).
  • Add a new column, row, by running row_number() over the window. row_number() returns a sequential number starting at 1 within each partition; rows that tie on salary still get distinct numbers, in no guaranteed order.
  • Use filter() to keep only the rows where row == 1, which is the first row of each group.
  • Finally, drop the row column since it is no longer needed.

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|        Maria|   Finance|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+

You can achieve the same result with a PySpark SQL expression:


df.createOrReplaceTempView("EMP")
spark.sql("""
    SELECT employee_name, department, salary
    FROM (SELECT *, row_number() OVER (PARTITION BY department ORDER BY salary) AS rn
          FROM EMP) tmp
    WHERE rn = 1
""").show()

3. Retrieve the Employee Who Earns the Highest Salary

To retrieve the employee with the highest salary in each department, order the window by salary in descending order with desc() and keep the first row.


w3 = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(w3)) \
  .filter(col("row") == 1).drop("row") \
  .show()

This outputs the following:


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|      Michael|     Sales|  4600|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
+-------------+----------+------+

4. Select the Highest, Lowest, Average, and Total Salary for Each Department

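Here, we retrieve the highest, average, total, and lowest salary for each department. The snippet below reuses the ordered window w3 from the previous section for row_number(), and uses a separate unordered window w4 for the aggregate functions avg, sum, min, and max. The unordered window matters: when a window has an orderBy(), its default frame runs only from the start of the partition up to the current row (and its ties), so the aggregates would become running values instead of department totals. Filtering on row == 1 then keeps one row per department.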


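# Import the functions module as F: importing sum, min, and max directly
# would shadow Python's built-in functions of the same names
from pyspark.sql import functions as F
from pyspark.sql.functions import col, row_number
# Unordered window, so each aggregate covers every row in the department
w4 = Window.partitionBy("department")
df.withColumn("row",row_number().over(w3)) \
  .withColumn("avg", F.avg(col("salary")).over(w4)) \
  .withColumn("sum", F.sum(col("salary")).over(w4)) \
  .withColumn("min", F.min(col("salary")).over(w4)) \
  .withColumn("max", F.max(col("salary")).over(w4)) \
  .where(col("row")==1).select("department","avg","sum","min","max") \
  .show()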

This outputs the following aggregated values for each department:


+----------+------+-----+----+----+
|department|   avg|  sum| min| max|
+----------+------+-----+----+----+
|     Sales|3900.0|11700|3000|4600|
|   Finance|3300.0|13200|3000|3900|
| Marketing|2500.0| 5000|2000|3000|
+----------+------+-----+----+----+

5. Complete Example of PySpark Select First Row of Each Group

Below is the complete example.


from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

data = [("James","Sales",3000),("Michael","Sales",4600),
      ("Robert","Sales",4100),("Maria","Finance",3000),
      ("Raman","Finance",3000),("Scott","Finance",3300),
      ("Jen","Finance",3900),("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)
    ]

df = spark.createDataFrame(data,["employee_name","department","salary"])
df.show()

# Select first row of each group (lowest salary)
w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row",row_number().over(w2)) \
  .filter(col("row") == 1).drop("row") \
  .show()

# Get the highest salary of each group
w3 = Window.partitionBy("department").orderBy(col("salary").desc())
df.withColumn("row",row_number().over(w3)) \
  .filter(col("row") == 1).drop("row") \
  .show()

# Get max, min, avg, sum of each group
# (F.sum/F.min/F.max avoid shadowing Python's built-in sum/min/max)
w4 = Window.partitionBy("department")
df.withColumn("row",row_number().over(w3)) \
  .withColumn("avg", F.avg(col("salary")).over(w4)) \
  .withColumn("sum", F.sum(col("salary")).over(w4)) \
  .withColumn("min", F.min(col("salary")).over(w4)) \
  .withColumn("max", F.max(col("salary")).over(w4)) \
  .where(col("row")==1).select("department","avg","sum","min","max") \
  .show()

6. Conclusion

In this article, you learned how to retrieve the first row of each group in a PySpark DataFrame using window functions, and how to get the maximum, minimum, average, and total of each group, with examples.

Happy Learning !!


NNK

