In Spark, you can select the maximum (max) row per group in a DataFrame by using the row_number() window function to rank the rows within each partition (group) by the value column in descending order, then filtering the DataFrame to keep only the rows with rank 1, which represent the maximum value within each group.
Here is an example in Scala:
1. Prepare Data & DataFrame
First, let’s create a Spark DataFrame with three columns: Name, Department, and Salary. The Department column contains the different departments used for grouping.
// Import
import org.apache.spark.sql.SparkSession
// Create Spark Session
val spark = SparkSession.builder()
  .master("local[1]")
  .appName("SparkByExample")
  .getOrCreate()
// Create data
val data = Seq(("James","Sales",3000),("Michael","Sales",4600),
("Robert","Sales",4100),("Maria","Finance",3000),
("Raman","Finance",3000),("Scott","Finance",3300),
("Jen","Finance",3900),("Jeff","Marketing",3000),
("Kumar","Marketing",2000))
// Create DataFrame
val df = spark.createDataFrame(data).toDF("Name","Department","Salary")
df.show()
Yields below output.
// Output:
+-------+----------+------+
| Name|Department|Salary|
+-------+----------+------+
| James| Sales| 3000|
|Michael| Sales| 4600|
| Robert| Sales| 4100|
| Maria| Finance| 3000|
| Raman| Finance| 3000|
| Scott| Finance| 3300|
| Jen| Finance| 3900|
| Jeff| Marketing| 3000|
| Kumar| Marketing| 2000|
+-------+----------+------+
From the above DataFrame, we want to select the rows holding the maximum salary of each department.
2. Select Maximum Row per Group in DataFrame
This code snippet demonstrates how to select the rows with the maximum (max) value per group in a Spark DataFrame. It uses the row_number() window function to rank rows within each partition (group) based on the Salary column in descending order.

Note that this code assumes you have a DataFrame df with columns named "Department" and "Salary" (column resolution is case-insensitive by default, so the lowercase names in the snippet resolve against them), and that you want the maximum value within each "Department" group. You may need to adjust column names and data types to match your specific DataFrame structure.
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
// Define a window specification to partition by the department column
// and order by the salary column
val windowDept = Window.partitionBy("department").orderBy(desc("salary"))
// Use the window function to rank rows within each partition
val rankedDF = df.withColumn("row",row_number().over(windowDept))
rankedDF.show()
It adds a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition.
The above example yields the below output.
// Output:
+-------+----------+------+---+
| Name|Department|Salary|row|
+-------+----------+------+---+
| Jen| Finance| 3900| 1|
| Scott| Finance| 3300| 2|
| Maria| Finance| 3000| 3|
| Raman| Finance| 3000| 4|
| Jeff| Marketing| 3000| 1|
| Kumar| Marketing| 2000| 2|
|Michael| Sales| 4600| 1|
| Robert| Sales| 4100| 2|
| James| Sales| 3000| 3|
+-------+----------+------+---+
Finally, filter the DataFrame to keep only the rows with rank 1, which represent the maximum value within each group.
// Filter the DataFrame to keep only the rows with rank 1
// (i.e., the maximum salary within each group)
rankedDF.filter(col("row") === 1).drop("row")
.show()
Yields below output.
// Output:
+-------+----------+------+
| Name|Department|Salary|
+-------+----------+------+
| Jen| Finance| 3900|
| Jeff| Marketing| 3000|
|Michael| Sales| 4600|
+-------+----------+------+
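One caveat worth noting: row_number() always returns exactly one rank-1 row per group, so when several employees tie for a department’s top salary, all but one of them are arbitrarily dropped. If you want to keep every tied row instead, one option is to swap in dense_rank(), which assigns equal salaries the same rank. Here is a minimal sketch under that assumption, using a reduced dataset in which Finance has two rows tied for the maximum:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("TiesExample")
  .getOrCreate()
import spark.implicits._

// Reduced example data: Maria and Raman tie for Finance's top salary
val df = Seq(("Maria", "Finance", 3000), ("Raman", "Finance", 3000),
             ("Scott", "Finance", 2800), ("Kumar", "Marketing", 2000))
  .toDF("Name", "Department", "Salary")

// dense_rank() (unlike row_number()) gives rows with equal salaries the
// same rank, so every row tied for the maximum survives the filter
val w = Window.partitionBy("Department").orderBy(desc("Salary"))
val topWithTies = df.withColumn("rn", dense_rank().over(w))
  .filter(col("rn") === 1)
  .drop("rn")
topWithTies.show()
```

With this variant, both Maria and Raman are retained for Finance, whereas row_number() would have kept only one of them.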
3. SQL expression to Find the Maximum Row per Group
You can also get the maximum (max) row for each group using a Spark SQL expression. If you have an SQL background, this will feel familiar. To use SQL, first create a temporary view using createOrReplaceTempView().

In Apache Spark, df.createOrReplaceTempView("EMP") registers the DataFrame df as a temporary view, named "EMP" in this case, in the SparkSession’s catalog. This temporary view allows you to run SQL queries against the DataFrame.
// Spark SQL Example
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
The lifetime of this view is tied to your current SparkSession; if you want to drop the view, use spark.catalog.dropTempView("tempViewName").
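For example, a minimal sketch that registers the "EMP" view and then drops it (dropTempView returns true when the named view existed and was removed):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("DropViewExample")
  .getOrCreate()
import spark.implicits._

// Register a temporary view named EMP, as in the article
Seq(("Jen", "Finance", 3900)).toDF("Name", "Department", "Salary")
  .createOrReplaceTempView("EMP")

// Drop the view once it is no longer needed; returns true if it existed
val dropped = spark.catalog.dropTempView("EMP")
println(dropped)  // true
```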
Here’s what each part does:
- df: the DataFrame you’ve previously defined or obtained through earlier operations in your Spark application.
- createOrReplaceTempView: creates a temporary view in the Spark catalog, or replaces an existing view with the same name. The view exists only for the duration of the SparkSession.
- "EMP": the name given to the temporary view.
- spark.sql(): executes the SQL statement.
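Besides window functions, another common way to get the same result is to aggregate the maximum salary per department and join it back to the original DataFrame. A sketch of that approach (note that, unlike row_number(), this variant keeps all rows tied for the maximum):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("GroupByJoinExample")
  .getOrCreate()
import spark.implicits._

val df = Seq(("James", "Sales", 3000), ("Michael", "Sales", 4600),
             ("Robert", "Sales", 4100), ("Jen", "Finance", 3900),
             ("Scott", "Finance", 3300), ("Jeff", "Marketing", 3000),
             ("Kumar", "Marketing", 2000))
  .toDF("Name", "Department", "Salary")

// Compute the per-department maximum salary, then inner-join it back
// on both Department and Salary to recover the full matching rows
val maxDF = df.groupBy("Department").agg(max("Salary").as("Salary"))
val result = df.join(maxDF, Seq("Department", "Salary"))
result.show()
```

This avoids a window function entirely, at the cost of an extra shuffle for the join.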
Conclusion
In summary, we have done the following to select the maximum (max) row of each group in a Spark DataFrame:
- Partition the DataFrame on the department column, which groups the rows of each department together.
- Apply orderBy() on the salary column in descending order.
- Add a new column row by running the row_number() function over the partition window; row_number() returns a sequential number starting from 1 within each window partition.
- Using the Spark filter(), select row == 1, which returns the maximum salary of each group.
- Finally, if the row column is not needed, drop it.
Happy Learning !!