In Spark, the maximum (max) row per group can be found by partitioning the data with the Window partitionBy() function and running the row_number() function over the window partition. Let's see this with a DataFrame example.
1. Prepare Data & DataFrame
First, let's create a Spark DataFrame with three columns: employee_name, department, and salary. The department column contains the different departments used for grouping.
// Prepare Data & DataFrame
val simpleData = Seq(("James","Sales",3000),
("Michael","Sales",4600),("Robert","Sales",4100),
("Maria","Finance",3000),("Raman","Finance",3000),
("Scott","Finance",3300),("Jen","Finance",3900),
("Jeff","Marketing",3000),("Kumar","Marketing",2000)
)
import spark.implicits._
val df = simpleData.toDF("employee_name","department","salary")
df.show()
This yields the output below.
// Output:
+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
| James| Sales| 3000|
| Michael| Sales| 4600|
| Robert| Sales| 4100|
| Maria| Finance| 3000|
| Raman| Finance| 3000|
| Scott| Finance| 3300|
| Jen| Finance| 3900|
| Jeff| Marketing| 3000|
| Kumar| Marketing| 2000|
+-------------+----------+------+
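Before jumping to window functions, it helps to see why a plain aggregation is not enough: groupBy().agg(max(...)) returns the maximum salary per department but drops the rest of the matching row. A minimal sketch for contrast, assuming the df created above:
// For contrast: groupBy() returns only the max salary per department,
// losing the employee_name of the row that earns it.
import org.apache.spark.sql.functions.max
df.groupBy("department")
  .agg(max("salary").as("max_salary"))
  .show()
// Output has only department and max_salary columns; employee_name is gone.
This is why we need a window function: it lets us keep the whole row while still computing per-group values.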
2. Spark Find Maximum Row per Group in DataFrame
We can select/find the maximum row per group using either Spark SQL or the DataFrame API. In this section, we use the DataFrame API with the window functions row_number(), partitionBy(), and orderBy(). This example finds the highest salary in each department group.
// Spark Find Maximum Row per Group in DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
val windowDept = Window.partitionBy("department").orderBy(col("salary").desc)
df.withColumn("row",row_number.over(windowDept ))
.where($"row" === 1).drop("row")
.show()
// Output:
+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
| Michael| Sales| 4600|
| Jen| Finance| 3900|
| Jeff| Marketing| 3000|
+-------------+----------+------+
The above example performs the following steps; see the sketch after this list for the intermediate result.
- First, partition the DataFrame on the department column, which groups all rows with the same department together.
- Apply orderBy() on the salary column in descending order, so the highest salary comes first within each partition.
- Add a new column row by running the row_number() function over the partition window. row_number() returns a sequential number starting from 1 within each window partition.
- Using filter() (or where()), select row == 1, which returns the row with the maximum salary of each group.
- Finally, since the row column is no longer needed, drop it.
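To make the row numbering visible, here is a minimal sketch that keeps the row column instead of dropping it, assuming the df and windowDept defined above. Note that the relative order of tied salaries (such as Maria and Raman in Finance) is not guaranteed:
// Keep the row column to inspect the numbering within each partition.
df.withColumn("row", row_number().over(windowDept))
  .show()
// Each department restarts numbering at 1 on its highest salary;
// only the rows with row == 1 survive the filter above.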
3. Spark SQL expression to Find Maximum Row per Group
You can also get the maximum record for each group using a SQL expression. If you have a SQL background, this approach will feel more familiar. In order to use SQL, you first need to create a temporary view using createOrReplaceTempView().
df.createOrReplaceTempView("EMP")
spark.sql("select Name, Department, Salary from "+
" (select *, row_number() OVER (PARTITION BY department ORDER BY salary DESC) as rn " +
" FROM EMP) tmp where rn = 1").show()
This yields the same output as above. Note that createOrReplaceTempView() creates or replaces a local temporary view backed by the DataFrame df. The lifetime of this view is tied to your current SparkSession; if you want to drop the view, use spark.catalog.dropTempView("EMP").
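One caveat worth noting: row_number() keeps exactly one row per group, so if two employees tie for a department's top salary, one of them is dropped arbitrarily. To keep all tied rows, a minimal sketch using rank() instead (an alternative to the example above, reusing the windowDept from section 2):
// rank() assigns the same number to rows with equal salaries,
// so every employee sharing the department maximum is kept.
df.withColumn("rnk", rank().over(windowDept))
  .where($"rnk" === 1).drop("rnk")
  .show()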
4. Complete Example
Following is the complete example.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
object WindowGroupbyMax extends App {
val spark: SparkSession = SparkSession.builder()
.master("local[1]")
.appName("SparkByExamples.com")
.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
val simpleData = Seq(("James","Sales",3000),
("Michael","Sales",4600),("Robert","Sales",4100),
("Maria","Finance",3000),("Raman","Finance",3000),
("Scott","Finance",3300),("Jen","Finance",3900),
("Jeff","Marketing",3000),("Kumar","Marketing",2000)
)
import spark.implicits._
val df = simpleData.toDF("employee_name","department","salary")
df.show()
val w2 = Window.partitionBy("department").orderBy(col("salary").desc)
df.withColumn("row",row_number.over(w2))
.where($"row" === 1).drop("row")
.show()
}
5. Conclusion
In summary, you can find the maximum (max) row for each group by partitioning the data by group using the window partitionBy() function, sorting the data within each partition, adding a row_number() column over the sorted partition, and finally filtering to keep the first record of each group.
Happy Learning !!
Related Articles
- Spark Groupby Example with DataFrame
- Spark DataFrame Select First Row of Each Group?
- Spark Internal Execution plan
- What is Apache Spark Driver?
- What is Spark Streaming Checkpoint?
- Spark Convert a Row into Case Class
- Difference in DENSE_RANK and ROW_NUMBER in Spark
- Spark DataFrame – Fetch More Than 20 Rows & Column Full Value