Spark Select Max Row Per Group in DataFrame

In Spark, you can select the maximum (max) row per group in a DataFrame by using the row_number() window function to rank rows within each partition (group) by a value column in descending order. You then filter the DataFrame to keep only the rows with row number 1, which represent the maximum value within each group.

Here is an example in Scala:

1. Prepare Data & DataFrame

First, let’s create a Spark DataFrame with three columns: Name, Department, and Salary. The Department column contains different departments, which we will use for grouping.


// Import
import org.apache.spark.sql.SparkSession

// Create Spark Session
val spark = SparkSession.builder()
    .master("local[1]")
    .appName("SparkByExample")
    .getOrCreate()

// Create data
val data = Seq(("James","Sales",3000),("Michael","Sales",4600),
  ("Robert","Sales",4100),("Maria","Finance",3000),
  ("Raman","Finance",3000),("Scott","Finance",3300),
  ("Jen","Finance",3900),("Jeff","Marketing",3000),
  ("Kumar","Marketing",2000))

// Create DataFrame
val df = spark.createDataFrame(data).toDF("Name","Department","Salary")
df.show()

Yields below output.


// Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|  James|     Sales|  3000|
|Michael|     Sales|  4600|
| Robert|     Sales|  4100|
|  Maria|   Finance|  3000|
|  Raman|   Finance|  3000|
|  Scott|   Finance|  3300|
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|  Kumar| Marketing|  2000|
+-------+----------+------+

From the above DataFrame, we want to select the rows with the maximum salary in each department: Jen (Finance, 3900), Jeff (Marketing, 3000), and Michael (Sales, 4600).


2. Select Maximum Row per Group in DataFrame

This code snippet demonstrates how to select the rows with the maximum (max) value per group in a Spark DataFrame. It uses the row_number() window function to rank rows within each partition (group) based on the salary column in descending order.

Note that this code assumes you have a DataFrame df with columns named "Department" and "Salary" (Spark’s column resolution is case-insensitive by default, so the lowercase "department" and "salary" used below also work), and that you want to find the maximum value within each group of the "Department" column. You may need to adjust column names and data types according to your specific DataFrame structure.


import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// Define a window specification to partition by the department column 
// and order by the salary column
val windowDept = Window.partitionBy("department").orderBy(desc("salary"))

// Use the window function to rank rows within each partition
val rankedDF = df.withColumn("row", row_number().over(windowDept))
rankedDF.show()

This adds a new column, row, by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition.

The above example yields the below output.


// Output:
+-------+----------+------+---+
|   Name|Department|Salary|row|
+-------+----------+------+---+
|    Jen|   Finance|  3900|  1|
|  Scott|   Finance|  3300|  2|
|  Maria|   Finance|  3000|  3|
|  Raman|   Finance|  3000|  4|
|   Jeff| Marketing|  3000|  1|
|  Kumar| Marketing|  2000|  2|
|Michael|     Sales|  4600|  1|
| Robert|     Sales|  4100|  2|
|  James|     Sales|  3000|  3|
+-------+----------+------+---+
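
Notice that Maria and Raman tie at 3000 in the Finance department and receive rows 3 and 4 in an arbitrary order; row_number() always assigns a distinct number to each row, even for ties. If you instead want to keep every row that ties for the top salary in a department, you could use rank(), which gives tied rows the same rank. Here is a minimal sketch reusing the windowDept defined above (with this sample data the result is the same, since no department has a tie at the top):

// Keep all rows tied for the top salary by using rank(),
// which assigns the same rank value to ties
val tiedTopDF = df.withColumn("rnk", rank().over(windowDept))
  .filter(col("rnk") === 1)
  .drop("rnk")
tiedTopDF.show()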

Finally, filter the DataFrame to keep only the rows with row number 1, which represent the maximum salary within each group.


// Filter the DataFrame to keep only the rows with rank 1 
// (i.e., the maximum salary within each group)
rankedDF.filter(col("row") === 1).drop("row")
    .show()

Yields below output.


// Output:
+-------+----------+------+
|   Name|Department|Salary|
+-------+----------+------+
|    Jen|   Finance|  3900|
|   Jeff| Marketing|  3000|
|Michael|     Sales|  4600|
+-------+----------+------+
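
If you only need the maximum salary value per department, and not the whole row, a simple aggregation is enough; it is also a quick way to sanity-check the result above. A minimal sketch using the same df:

// Max salary per department only; this drops the Name column,
// which is why the window-function approach above is used
// when the full row is needed
df.groupBy("Department")
  .agg(max("Salary").alias("MaxSalary"))
  .show()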

3. SQL expression to Find the Maximum Row per Group

You can also get the maximum (max) row for each group using a Spark SQL expression. If you have a SQL background, this will feel familiar. To use SQL, you first need to create a temporary view using createOrReplaceTempView().

In Apache Spark, df.createOrReplaceTempView("EMP") creates a temporary view of a DataFrame (df) in the Spark Session’s catalog. This temporary view, named "EMP" in this case, allows you to run SQL queries against the DataFrame.


// Spark SQL Example
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
     " (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
     " FROM EMP) tmp where rn = 1").show()

The lifetime of this view is tied to your current SparkSession; if you want to drop the view, use spark.catalog.dropTempView("tempViewName").

Here’s what each part does:

  • df: This is a DataFrame you’ve previously defined or obtained through some operations in your Spark application.
  • createOrReplaceTempView: This function creates a temporary view in the Spark catalog or replaces an existing view if it already exists. The view is temporary, meaning it will exist only for the duration of the Spark Session.
  • "EMP": This is the name given to the temporary view. In this case, the view is named "EMP".
  • spark.sql(): This is used to execute the SQL statements.
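
For example, to drop the "EMP" view created above once you no longer need it, a one-line sketch using the catalog call mentioned earlier:

// Drop the temporary view when it is no longer needed
spark.catalog.dropTempView("EMP")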

Conclusion

In summary, we have done the following to select the row with the maximum (max) value in each group of a Spark DataFrame (a consolidated snippet follows the list):

  • First, partition the DataFrame on the department column, which groups all rows of the same department together.
  • Apply orderBy() on the salary column in descending order.
  • Add a new column, row, by running the row_number() function over the partition window; row_number() returns a sequential number starting from 1 within each window partition.
  • Using Spark filter(), select the rows where row == 1, which are the maximum-salary rows of each group.
  • Finally, if the row column is not needed, just drop it.
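
Putting these steps together, the whole selection can be chained into a single expression; this is the same logic shown above, written compactly:

// Rank within each department by salary (descending),
// keep only row 1, and drop the helper column
df.withColumn("row", row_number().over(
    Window.partitionBy("Department").orderBy(desc("Salary"))))
  .filter(col("row") === 1)
  .drop("row")
  .show()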

Happy Learning !!

Naveen Nelamali

Naveen Nelamali (NNK) is a Data Engineer with 20+ years of experience in transforming data into actionable insights. Over the years, he has honed his expertise in designing, implementing, and maintaining data pipelines with frameworks like Apache Spark, PySpark, Pandas, R, Hive, and Machine Learning. Naveen's journey in the field of data engineering has been one of continuous learning, innovation, and a strong commitment to data integrity. In this blog, he shares his experiences with data as he comes across them. Follow Naveen @ LinkedIn and Medium