Spark SQL Aggregate Functions

Spark SQL provides built-in standard aggregate functions defined in the DataFrame API; these come in handy when we need to perform aggregate operations on DataFrame columns. Aggregate functions operate on a group of rows and calculate a single return value for every group.

All these aggregate functions accept input as a Column type or a column name as a String, along with other function-specific arguments, and return a Column type.
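For example, both calls below are equivalent; a minimal sketch, assuming a DataFrame df with a salary column like the one created later in this article:

  df.select(avg("salary"))        // column name as a String
  df.select(avg(col("salary")))   // Column type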

When possible, try to leverage the standard library functions, as they offer more compile-time safety, handle nulls, and perform better than UDFs. If your application is performance-critical, try to avoid custom UDFs at all costs, as they come with no performance guarantees.
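The examples in this article apply the aggregates over the whole DataFrame with select(); in practice you will often apply them per group via groupBy(). A minimal sketch using the same DataFrame created below:

  df.groupBy("department")
    .agg(avg("salary"), max("salary"))
    .show(false)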

Spark Aggregate Functions

Spark SQL aggregate functions are grouped as “agg_funcs” in Spark SQL. Below is a list of functions defined under this group; each is demonstrated with a Scala example later in this article.

Note that each of the functions below has an alternate signature that takes a String column name instead of a Column.

AGGREGATE FUNCTION SYNTAX | AGGREGATE FUNCTION DESCRIPTION
approx_count_distinct(e: Column) | Returns the approximate count of distinct items in a group.
approx_count_distinct(e: Column, rsd: Double) | Returns the approximate count of distinct items in a group; rsd is the maximum relative standard deviation allowed.
avg(e: Column) | Returns the average of values in the input column.
collect_list(e: Column) | Returns all values from an input column, with duplicates.
collect_set(e: Column) | Returns all values from an input column, with duplicate values eliminated.
corr(column1: Column, column2: Column) | Returns the Pearson correlation coefficient for two columns.
count(e: Column) | Returns the number of elements in a column.
countDistinct(expr: Column, exprs: Column*) | Returns the number of distinct elements in the given columns.
covar_pop(column1: Column, column2: Column) | Returns the population covariance for two columns.
covar_samp(column1: Column, column2: Column) | Returns the sample covariance for two columns.
first(e: Column, ignoreNulls: Boolean) | Returns the first element in a column; when ignoreNulls is true, returns the first non-null element.
first(e: Column) | Returns the first element in a column.
grouping(e: Column) | Indicates whether a specified column in a GROUP BY list is aggregated or not; returns 1 for aggregated, 0 for not aggregated, in the result set.
kurtosis(e: Column) | Returns the kurtosis of the values in a group.
last(e: Column, ignoreNulls: Boolean) | Returns the last element in a column; when ignoreNulls is true, returns the last non-null element.
last(e: Column) | Returns the last element in a column.
max(e: Column) | Returns the maximum value in a column.
mean(e: Column) | Alias for avg. Returns the average of the values in a column.
min(e: Column) | Returns the minimum value in a column.
skewness(e: Column) | Returns the skewness of the values in a group.
stddev(e: Column) | Alias for stddev_samp.
stddev_samp(e: Column) | Returns the sample standard deviation of values in a column.
stddev_pop(e: Column) | Returns the population standard deviation of the values in a column.
sum(e: Column) | Returns the sum of all values in a column.
sumDistinct(e: Column) | Returns the sum of all distinct values in a column.
variance(e: Column) | Alias for var_samp.
var_samp(e: Column) | Returns the unbiased (sample) variance of the values in a column.
var_pop(e: Column) | Returns the population variance of the values in a column.

Aggregate Functions Examples

First, let’s create a DataFrame to work with aggregate functions. All examples provided here are also available at the GitHub project.


 import spark.implicits._

  val simpleData = Seq(("James", "Sales", 3000),
    ("Michael", "Sales", 4600),
    ("Robert", "Sales", 4100),
    ("Maria", "Finance", 3000),
    ("James", "Sales", 3000),
    ("Scott", "Finance", 3300),
    ("Jen", "Finance", 3900),
    ("Jeff", "Marketing", 3000),
    ("Kumar", "Marketing", 2000),
    ("Saif", "Sales", 4100)
  )
  val df = simpleData.toDF("employee_name", "department", "salary")
  df.show()

This yields the output below.


+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        James|     Sales|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
|         Saif|     Sales|  4100|
+-------------+----------+------+

approx_count_distinct Aggregate Function

approx_count_distinct() function returns the approximate count of distinct items in a group. In the example below, collect() returns an Array[Row], so collect()(0)(0) takes the first column of the first Row of the result.


  //approx_count_distinct()
  println("approx_count_distinct: "+
    df.select(approx_count_distinct("salary")).collect()(0)(0))

//Prints approx_count_distinct: 6
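
The second signature also takes rsd, the maximum relative standard deviation (estimation error) allowed. A minimal sketch with a 10% error bound:

  //approx_count_distinct() with maximum estimation error allowed
  println("approx_count_distinct (rsd=0.1): "+
    df.select(approx_count_distinct("salary", 0.1)).collect()(0)(0))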

avg (average) Aggregate Function

avg() function returns the average of values in the input column.


  //avg
  println("avg: "+
    df.select(avg("salary")).collect()(0)(0))

//Prints avg: 3400.0

collect_list Aggregate Function

collect_list() function returns all values from an input column with duplicates.


  //collect_list
  df.select(collect_list("salary")).show(false)

+------------------------------------------------------------+
|collect_list(salary)                                        |
+------------------------------------------------------------+
|[3000, 4600, 4100, 3000, 3000, 3300, 3900, 3000, 2000, 4100]|
+------------------------------------------------------------+

collect_set Aggregate Function

collect_set() function returns all values from an input column with duplicate values eliminated.


  //collect_set
  df.select(collect_set("salary")).show(false)

+------------------------------------+
|collect_set(salary)                 |
+------------------------------------+
|[4600, 3000, 3900, 4100, 3300, 2000]|
+------------------------------------+

countDistinct Aggregate Function

countDistinct() function returns the number of distinct elements in the given columns.


  //countDistinct
  val df2 = df.select(countDistinct("department", "salary"))
  df2.show(false)
  println("Distinct Count of Department & Salary: "+df2.collect()(0)(0))

count() Aggregate Function

count() function returns the number of elements in a column.


  println("count: "+
    df.select(count("salary")).collect()(0)(0))

//Prints count: 10

grouping() Aggregate Function

grouping() indicates whether a given input column is aggregated or not, and returns 1 for aggregated or 0 for not aggregated in the result. If you try grouping() directly on the salary column, you will get the error below.


Exception in thread "main" org.apache.spark.sql.AnalysisException:
  grouping() can only be used with GroupingSets/Cube/Rollup;
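
grouping() works only together with cube(), rollup(), or GROUPING SETS. As a minimal sketch on the DataFrame above, grouping("department") returns 1 on the aggregated grand-total row and 0 on the regular department rows:

  df.cube("department")
    .agg(grouping("department"), sum("salary"))
    .show(false)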

first() Aggregate Function

first() function returns the first element in a column. When ignoreNulls is set to true, it returns the first non-null element.


//first
  df.select(first("salary")).show(false)

+--------------------+
|first(salary, false)|
+--------------------+
|3000                |
+--------------------+
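
first() (like last() below) also takes an ignoreNulls flag; a quick sketch (our sample data contains no nulls, so the result is the same here):

  df.select(first("salary", ignoreNulls = true)).show(false)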

last()

last() function returns the last element in a column. When ignoreNulls is set to true, it returns the last non-null element.


//last
  df.select(last("salary")).show(false)

+-------------------+
|last(salary, false)|
+-------------------+
|4100               |
+-------------------+

kurtosis()

kurtosis() function returns the kurtosis of the values in a group.


  df.select(kurtosis("salary")).show(false)

+-------------------+
|kurtosis(salary)   |
+-------------------+
|-0.6467803030303032|
+-------------------+

max()

max() function returns the maximum value in a column.


  df.select(max("salary")).show(false)

+-----------+
|max(salary)|
+-----------+
|4600       |
+-----------+

min()

min() function returns the minimum value in a column.


df.select(min("salary")).show(false)

+-----------+
|min(salary)|
+-----------+
|2000       |
+-----------+

mean()

mean() function returns the average of the values in a column; it is an alias for avg().


df.select(mean("salary")).show(false)

+-----------+
|avg(salary)|
+-----------+
|3400.0     |
+-----------+

skewness()

skewness() function returns the skewness of the values in a group.


df.select(skewness("salary")).show(false)

+--------------------+
|skewness(salary)    |
+--------------------+
|-0.12041791181069571|
+--------------------+

stddev(), stddev_samp() and stddev_pop()

stddev() is an alias for stddev_samp().

stddev_samp() function returns the sample standard deviation of values in a column.

stddev_pop() function returns the population standard deviation of the values in a column.


  df.select(stddev("salary"), stddev_samp("salary"), 
    stddev_pop("salary")).show(false)

+-------------------+-------------------+------------------+
|stddev_samp(salary)|stddev_samp(salary)|stddev_pop(salary)|
+-------------------+-------------------+------------------+
|765.9416862050705  |765.9416862050705  |726.636084983398  |
+-------------------+-------------------+------------------+
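
For reference, the sample standard deviation divides by n - 1 while the population version divides by n, which is why stddev_pop comes out slightly smaller on the same data.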

sum()

sum() function returns the sum of all values in a column.


df.select(sum("salary")).show(false)

+-----------+
|sum(salary)|
+-----------+
|34000      |
+-----------+

sumDistinct()

sumDistinct() function returns the sum of all distinct values in a column.


df.select(sumDistinct("salary")).show(false)

+--------------------+
|sum(DISTINCT salary)|
+--------------------+
|20900               |
+--------------------+

variance(), var_samp(), var_pop()

variance() is an alias for var_samp().

var_samp() function returns the unbiased variance of the values in a column.

var_pop() function returns the population variance of the values in a column.


df.select(variance("salary"),var_samp("salary"),var_pop("salary"))
  .show(false)

+-----------------+-----------------+---------------+
|var_samp(salary) |var_samp(salary) |var_pop(salary)|
+-----------------+-----------------+---------------+
|586666.6666666666|586666.6666666666|528000.0       |
+-----------------+-----------------+---------------+
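
As a quick sanity check, var_samp is the square of stddev_samp: 765.9416862^2 ≈ 586666.67, which matches the output above.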

Source code of Spark SQL Aggregate Functions examples


package com.sparkbyexamples.spark.dataframe.functions.aggregate

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object AggregateFunctions extends App {

  val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("SparkByExamples.com")
    .getOrCreate()

  spark.sparkContext.setLogLevel("ERROR")

  import spark.implicits._

  val simpleData = Seq(("James", "Sales", 3000),
    ("Michael", "Sales", 4600),
    ("Robert", "Sales", 4100),
    ("Maria", "Finance", 3000),
    ("James", "Sales", 3000),
    ("Scott", "Finance", 3300),
    ("Jen", "Finance", 3900),
    ("Jeff", "Marketing", 3000),
    ("Kumar", "Marketing", 2000),
    ("Saif", "Sales", 4100)
  )
  val df = simpleData.toDF("employee_name", "department", "salary")
  df.show()

  //approx_count_distinct()
  println("approx_count_distinct: "+
    df.select(approx_count_distinct("salary")).collect()(0)(0))

  //avg
  println("avg: "+
    df.select(avg("salary")).collect()(0)(0))

  //collect_list

  df.select(collect_list("salary")).show(false)

  //collect_set

  df.select(collect_set("salary")).show(false)

  //countDistinct
  val df2 = df.select(countDistinct("department", "salary"))
  df2.show(false)
  println("Distinct Count of Department & Salary: "+df2.collect()(0)(0))

  println("count: "+
    df.select(count("salary")).collect()(0))

  //first
  df.select(first("salary")).show(false)

  //last
  df.select(last("salary")).show(false)

  //Exception in thread "main" org.apache.spark.sql.AnalysisException: 
  // grouping() can only be used with GroupingSets/Cube/Rollup;
  //df.select(grouping("salary")).show(false)

  df.select(kurtosis("salary")).show(false)

  df.select(max("salary")).show(false)

  df.select(min("salary")).show(false)

  df.select(mean("salary")).show(false)

  df.select(skewness("salary")).show(false)

  df.select(stddev("salary"), stddev_samp("salary"),
    stddev_pop("salary")).show(false)

  df.select(sum("salary")).show(false)

  df.select(sumDistinct("salary")).show(false)

  df.select(variance("salary"),var_samp("salary"),
    var_pop("salary")).show(false)
}

Conclusion

In this article, I’ve consolidated and listed all Spark SQL aggregate functions with Scala examples, and also covered the benefits of using Spark SQL built-in functions over UDFs.

Happy Learning !!

Naveen Nelamali

Naveen Nelamali (NNK) is a Data Engineer with 20+ years of experience in transforming data into actionable insights. Over the years, he has honed his expertise in designing, implementing, and maintaining data pipelines with frameworks like Apache Spark, PySpark, Pandas, R, Hive, and Machine Learning. Naveen’s journey in the field of data engineering has been one of continuous learning, innovation, and a strong commitment to data integrity. In this blog, he shares his experiences with data as he comes across them. Follow Naveen @ LinkedIn and Medium
