Spark DataFrame: Select First Row of Each Group

In this Spark article, I explain how to select the first row, the minimum (min), and the maximum (max) of each group in a DataFrame using Spark SQL window functions, with a Scala example. Although the examples are written in Scala, the same approach works with PySpark and Python.

1. Preparing Data & DataFrame

Before we start, let’s create a DataFrame from a sequence of data to work with. This DataFrame contains 3 columns, “employee_name”, “department”, and “salary”; the “department” column holds the department names we will group on.

We will use this Spark DataFrame to select the first row of each group, the minimum salary of each group, and the maximum salary of each group. Finally, we will also see how to get the sum and the average salary for each department group.


import spark.implicits._

val simpleData = Seq(("James","Sales",3000),
  ("Michael","Sales",4600),
  ("Robert","Sales",4100),
  ("Maria","Finance",3000),
  ("Raman","Finance",3000),
  ("Scott","Finance",3300),
  ("Jen","Finance",3900),
  ("Jeff","Marketing",3000),
  ("Kumar","Marketing",2000)
)
// Column names match the ones used in the rest of the article
val df = simpleData.toDF("employee_name","department","salary")
df.show()

This outputs the following table.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        Raman|   Finance|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+

2. Select First Row From a Group

We can select the first row of each group using either Spark SQL or the DataFrame API. In this section, we will use the DataFrame API with the window function row_number and partitionBy.


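import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// Number the rows of each department by ascending salary, then keep row 1
val w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row", row_number().over(w2))
  .where($"row" === 1).drop("row")
  .show()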

In the snippet above, we first partition on the department column, which puts all rows with the same department into one group, and then order each group by the salary column. We then apply the row_number function over this window. row_number returns a sequential number starting from 1 within each window partition, so keeping only the rows where it equals 1 leaves the lowest-paid employee of each department. This snippet outputs the following.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|        Maria|   Finance|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+
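Note that Maria and Raman in Finance both earn 3000, and the window above orders by salary alone, so which of the two receives row number 1 is not guaranteed. The following is a minimal sketch of one way to make the choice repeatable, assuming it is acceptable to break ties on employee_name. The object name DeterministicFirstRow and the method lowestPaid are hypothetical names used only for this illustration.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

object DeterministicFirstRow {

  // Keeps one row per department: lowest salary first, then employee_name
  // as a tiebreaker (the tiebreak column is an assumption of this sketch).
  def lowestPaid(df: DataFrame): DataFrame = {
    val w = Window.partitionBy("department")
      .orderBy(col("salary"), col("employee_name"))
    df.withColumn("row", row_number().over(w))
      .where(col("row") === 1)
      .drop("row")
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("DeterministicFirstRow")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("Maria", "Finance", 3000), ("Raman", "Finance", 3000),
      ("Scott", "Finance", 3300), ("Jen", "Finance", 3900)
    ).toDF("employee_name", "department", "salary")

    // Maria is kept for Finance because "Maria" sorts before "Raman"
    lowestPaid(df).show()
    spark.stop()
  }
}
```

If all tied rows should be kept instead of just one, the rank function can be used in place of row_number, since it assigns the same number to rows with equal salaries.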

3. Retrieve Employee who earns the highest salary

To retrieve the employee with the highest salary in each department, we order each partition by “salary” in descending order and again keep the first row.


// Same approach as before, but with salary sorted in descending order
val w3 = Window.partitionBy("department").orderBy(col("salary").desc)
df.withColumn("row", row_number().over(w3))
  .where($"row" === 1).drop("row")
  .show()

This outputs the following.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|      Michael|     Sales|  4600|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
+-------------+----------+------+
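As mentioned earlier, the same result can also be obtained with Spark SQL. The sketch below registers the DataFrame as a temporary view and runs an equivalent row_number query. The view name EMP, the alias rn, and the object and method names are assumptions chosen for this illustration and do not come from the example above.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object HighestSalarySql {

  // Returns the highest-paid employee of each department using Spark SQL.
  def highestPaid(spark: SparkSession): DataFrame = {
    import spark.implicits._

    val df = Seq(
      ("James", "Sales", 3000), ("Michael", "Sales", 4600), ("Robert", "Sales", 4100),
      ("Maria", "Finance", 3000), ("Raman", "Finance", 3000), ("Scott", "Finance", 3300),
      ("Jen", "Finance", 3900), ("Jeff", "Marketing", 3000), ("Kumar", "Marketing", 2000)
    ).toDF("employee_name", "department", "salary")

    // EMP is a hypothetical view name used only in this sketch
    df.createOrReplaceTempView("EMP")

    // The inner query numbers the rows per department, highest salary first;
    // the outer query keeps only the first row of each department.
    spark.sql(
      """SELECT employee_name, department, salary
        |FROM (
        |  SELECT *,
        |         row_number() OVER (PARTITION BY department ORDER BY salary DESC) AS rn
        |  FROM EMP
        |) ranked
        |WHERE rn = 1""".stripMargin)
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("HighestSalarySql")
      .getOrCreate()
    highestPaid(spark).show()
    spark.stop()
  }
}
```

Changing ORDER BY salary DESC to ORDER BY salary gives the lowest-paid employee of each department instead.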

4. Select the Highest, Lowest, Average, and Total salary for each department group

Here, we will retrieve the highest, lowest, average, and total salary for each group. The snippet below uses partitionBy and row_number along with the aggregate functions avg, sum, min, and max.


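// w3 is the window from the previous section (department, salary descending);
// it is used here only to keep a single row per department.
val w4 = Window.partitionBy("department")
val aggDF = df.withColumn("row", row_number().over(w3))
  .withColumn("avg", avg(col("salary")).over(w4))
  .withColumn("sum", sum(col("salary")).over(w4))
  .withColumn("min", min(col("salary")).over(w4))
  .withColumn("max", max(col("salary")).over(w4))
  .where(col("row") === 1).select("department","avg","sum","min","max")
// show() returns Unit, so it is called separately to keep aggDF a DataFrame
aggDF.show()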

Outputs the following aggregated values for each group.


+----------+------+-----+----+----+
|department|   avg|  sum| min| max|
+----------+------+-----+----+----+
|     Sales|3900.0|11700|3000|4600|
|   Finance|3300.0|13200|3000|3900|
| Marketing|2500.0| 5000|2000|3000|
+----------+------+-----+----+----+
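When only the per-department aggregates are needed and no individual employee row has to be kept, the same table can also be produced with groupBy and agg, without row_number. The following is a minimal sketch of that alternative. The object name DepartmentAggregates and the method aggregates are hypothetical, and the input is assumed to have the department and salary columns used throughout this article.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{avg, max, min, sum}

object DepartmentAggregates {

  // One output row per department with the average, total, lowest and
  // highest salary; assumes columns named "department" and "salary".
  def aggregates(df: DataFrame): DataFrame =
    df.groupBy("department").agg(
      avg("salary").as("avg"),
      sum("salary").as("sum"),
      min("salary").as("min"),
      max("salary").as("max"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("DepartmentAggregates")
      .getOrCreate()
    import spark.implicits._

    val df = Seq(
      ("James", "Sales", 3000), ("Michael", "Sales", 4600), ("Robert", "Sales", 4100),
      ("Maria", "Finance", 3000), ("Raman", "Finance", 3000), ("Scott", "Finance", 3300),
      ("Jen", "Finance", 3900), ("Jeff", "Marketing", 3000), ("Kumar", "Marketing", 2000)
    ).toDF("employee_name", "department", "salary")

    aggregates(df).show()
    spark.stop()
  }
}
```

The window version is the better fit when other columns of the original rows must appear next to the aggregates, because groupBy returns only the grouping column and the aggregated values.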

5. Complete Example of Select First Row of a Group


package com.sparkbyexamples.spark.dataframe.functions
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._


object WindowGroupbyFirst extends App {

    val spark: SparkSession = SparkSession.builder()
      .master("local[1]")
      .appName("SparkByExamples.com")
      .getOrCreate()

    spark.sparkContext.setLogLevel("ERROR")

    import spark.implicits._

    val simpleData = Seq(("James","Sales",3000),
      ("Michael","Sales",4600),
      ("Robert","Sales",4100),
      ("Maria","Finance",3000),
      ("Raman","Finance",3000),
      ("Scott","Finance",3300),
      ("Jen","Finance",3900),
      ("Jeff","Marketing",3000),
      ("Kumar","Marketing",2000)
    )
    val df = simpleData.toDF("employee_name","department","salary")
    df.show()

    // Get the first row (lowest salary) of each group
    val w2 = Window.partitionBy("department").orderBy(col("salary"))
    df.withColumn("row", row_number().over(w2))
      .where($"row" === 1).drop("row")
      .show()

    // Retrieve the employee with the highest salary in each group
    val w3 = Window.partitionBy("department").orderBy(col("salary").desc)
    df.withColumn("row", row_number().over(w3))
      .where($"row" === 1).drop("row")
      .show()

    // Maximum, minimum, average, and total salary for each window group
    val w4 = Window.partitionBy("department")
    val aggDF = df.withColumn("row", row_number().over(w3))
      .withColumn("avg", avg(col("salary")).over(w4))
      .withColumn("sum", sum(col("salary")).over(w4))
      .withColumn("min", min(col("salary")).over(w4))
      .withColumn("max", max(col("salary")).over(w4))
      .where(col("row") === 1).select("department","avg","sum","min","max")
    // show() returns Unit, so it is called separately to keep aggDF a DataFrame
    aggDF.show()
}

6. Conclusion

In this article, you have learned how to retrieve the first row of each group, as well as the minimum, maximum, average, and sum for each group, in a Spark DataFrame.


Happy Learning !!

NNK


Comments

  1. Anonymous

    Hi team,

    The examples are simple and clear. It would be very helpful if you included more Spark architecture details such as DAG, memory management, and linear graph, since interviews concentrate more on the architecture side. It would help many people.

    1. NNK

      Thanks for reading how to select the first row of each group. Sure, I will add details like DAG and linear graph in the future. Thanks for the suggestions.

  2. Anonymous

    Would it not be more efficient to do the above as I have shown below? I would be interested to hear
    (df.groupBy('department').agg(
        F.max('salary').alias('max'),
        F.avg('salary').alias('avg'),
        F.min('salary').alias('min'))
    ).show()

    This article is great btw, I thoroughly enjoy how things are explained with real-life examples, not just a small snapshot of an example! Cheers

    1. NNK

      Thank you for the kind words and for your feedback. You can certainly use groupBy(). I believe window functions perform better than groupBy, though at this time I don’t have metrics to support this. I will soon write an article explaining the differences between groupBy and window functions. Stay tuned.
