Find Maximum Row per Group in Spark DataFrame

In Spark, the maximum (max) row per group can be found by partitioning the data with the window partitionBy() function and running the row_number() function over the window partition; let’s see this with a DataFrame example.

1. Prepare Data & DataFrame

First, let’s create a Spark DataFrame with three columns employee_name, department and salary. Column department contains different department names, which we will use for grouping.


// Prepare Data & DataFrame
val simpleData = Seq(("James","Sales",3000),
   ("Michael","Sales",4600),("Robert","Sales",4100),
   ("Maria","Finance",3000),("Raman","Finance",3000),
   ("Scott","Finance",3300),("Jen","Finance",3900),
   ("Jeff","Marketing",3000),("Kumar","Marketing",2000)
)
import spark.implicits._
val df = simpleData.toDF("employee_name","department","salary")
df.show()

This yields the below output.


// Output:
+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        Raman|   Finance|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+

2. Spark Find Maximum Row per Group in DataFrame

We can select/find the maximum row per group using either Spark SQL or the DataFrame API. In this section, we will see how to do it with the DataFrame API using the window functions row_number(), partitionBy() and orderBy().

This example calculates the highest salary of each department group.


// Spark Find Maximum Row per Group in DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val windowDept = Window.partitionBy("department").orderBy(col("salary").desc)
df.withColumn("row",row_number.over(windowDept))
  .where($"row" === 1).drop("row")
  .show()

// Displays
+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|      Michael|     Sales|  4600|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
+-------------+----------+------+

The above example performs the following steps.

  • First, partition the DataFrame on the department column, which groups all rows of the same department together.
  • Apply orderBy() on the salary column in descending order.
  • Add a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition.
  • Using the Spark filter() (or where()), select row == 1, which returns the row with the maximum salary in each group.
  • Finally, since the row column is no longer needed, drop it. The sketch after this list shows the intermediate result with the row column, before filtering.
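To make these steps concrete, here is a minimal sketch (reusing the df and windowDept defined above) that shows the intermediate row column before the filter and drop are applied. Note that the tied Finance salaries (Maria and Raman, both 3000) may be numbered in either order, since row_number() breaks ties arbitrarily.


// Sketch: inspect the intermediate "row" column before filtering
df.withColumn("row", row_number.over(windowDept))
  .orderBy($"department", $"row") // sort only to make the output easier to read
  .show()
// Each department's rows are numbered 1, 2, 3, ... by descending salary;
// the rows with row == 1 are the ones kept by the where($"row" === 1) filter above.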

3. Spark SQL expression to Find Maximum Row per Group

You can also get the maximum record for each group using a SQL expression. If you have an SQL background, this will be familiar to you. In order to use SQL, you first need to create a temporary view using createOrReplaceTempView().


df.createOrReplaceTempView("EMP")
spark.sql("select employee_name, department, salary from " +
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

This yields the same output as above. Note that createOrReplaceTempView() creates or replaces a local temporary view backed by the DataFrame df. The lifetime of this view is tied to your current SparkSession; if you want to drop the view, use spark.catalog.dropTempView("tempViewName").
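For example, once you are done with the EMP view created above, you could drop it as shown in this small sketch.


// Drop the local temporary view when it is no longer needed
spark.catalog.dropTempView("EMP")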

4. Complete Example

Following is the complete example.


import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object WindowGroupbyMax extends App {

    val spark: SparkSession = SparkSession.builder()
      .master("local[1]")
      .appName("SparkByExamples.com")
      .getOrCreate()

    spark.sparkContext.setLogLevel("ERROR")

    val simpleData = Seq(("James","Sales",3000),
      ("Michael","Sales",4600),("Robert","Sales",4100),
      ("Maria","Finance",3000),("Raman","Finance",3000),
      ("Scott","Finance",3300),("Jen","Finance",3900),
      ("Jeff","Marketing",3000),("Kumar","Marketing",2000)
    )
    import spark.implicits._
    val df = simpleData.toDF("employee_name","department","salary")
    df.show()

    val w2 = Window.partitionBy("department").orderBy(col("salary").desc)
    df.withColumn("row",row_number.over(w2))
      .where($"row" === 1).drop("row")
      .show()
}

5. Conclusion

In summary, you can find the maximum (max) row for each group by partitioning the data by group using window partitionBy(), sorting the partition data within each group, adding row_number() to the sorted data, and finally filtering to get the first record.

Happy Learning !!

