PySpark Groupby Agg (aggregate) – Explained

  • Post category: PySpark
  • Post last modified: August 13, 2022

PySpark's groupBy().agg() is used to calculate more than one aggregate (multiple aggregates) at a time on a grouped DataFrame. To perform the aggregation, first call groupBy() on the DataFrame, which groups the records based on one or more column values, and then call agg() to compute the aggregates for each group.

In this article, I will explain how to use the agg() function on a grouped DataFrame, with examples. The PySpark groupBy() function collects identical data into groups, and the agg() function performs aggregations such as count, sum, avg, min, and max on the grouped data.

1. Quick Examples of Groupby Agg

Following are quick examples of how to perform groupBy() and agg() (aggregate).


# Quick Examples

# Imports
from pyspark.sql.functions import sum,avg,max,count

# Example 1 - PySpark groupby agg
df.groupBy("department") \
    .agg(count("*").alias("count")) \
    .show(truncate=False)

# Example 2 - groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")) \
    .show(truncate=False)

# Example 3 - Multiple Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") ) \
    .show(truncate=False)

# Example 4 - SQL group by agg
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str = """select department, sum(salary) as sum_salary,
       avg(salary) as avg_salary,
       sum(bonus) as sum_bonus,
       max(bonus) as max_bonus
from EMP
group by department having sum_bonus >= 50000"""
spark.sql(sql_str).show()

Before running these examples, let’s create a DataFrame from a sequence of data to work with. This DataFrame contains the columns “employee_name”, “department”, “state”, “salary”, “age”, and “bonus”.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)

Yields below output.


2. PySpark Groupby Aggregate Example

By using DataFrame.groupBy().agg() in PySpark, you can get the number of rows for each group with the count aggregate function. DataFrame.groupBy() returns a pyspark.sql.GroupedData object, which provides an agg() method to perform aggregations on the grouped DataFrame.

After performing the aggregates, this method returns a PySpark DataFrame.


# Syntax
GroupedData.agg(*exprs)

To use aggregate functions such as sum(), avg(), min(), and max(), you have to import them from pyspark.sql.functions. In the example below, I calculate the number of rows for each group by grouping on the department column and using agg() with the count() function.


# PySpark groupBy() agg
from pyspark.sql.functions import count

df.groupBy("department") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)

Yields below output.


3. Groupby Agg on Multiple Columns

Groupby aggregate on multiple columns in PySpark can be performed by passing two or more columns to the groupBy() function and then calling agg(). The following example groups on the department and state columns and, on the result, applies the count() function within agg().


# groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")) \
    .show(truncate=False)

Yields below output.


4. Running Multiple Aggregates at a time

Using groupBy() and agg(), we can calculate multiple aggregates at a time in a single statement using PySpark SQL aggregate functions such as sum(), avg(), min(), max(), mean(), and count(). In order to use these, import them with "from pyspark.sql.functions import sum,avg,max,min,mean,count".


df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

This example groups on the department column and calculates the sum() and avg() of salary, as well as the sum() and max() of bonus, for each department.


5. Using Where on Aggregate DataFrame

Similar to the SQL HAVING clause, on a PySpark DataFrame we can use either the where() or filter() function to filter rows on top of the aggregated data.


# Using groupBy(), agg() and where()
from pyspark.sql.functions import col

df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
      avg("salary").alias("avg_salary"), \
      sum("bonus").alias("sum_bonus"), \
      max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)

Yields below output.

6. PySpark SQL GROUP BY & HAVING

Finally, let’s convert the above groupBy() agg() into a PySpark SQL query and execute it. To do so, first create a temporary view with createOrReplaceTempView() and run the query with SparkSession.sql(). The temporary view remains available until you end your SparkSession.


# PySpark SQL Group By AVG, SUM, MAX
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str = """select department, sum(salary) as sum_salary,
       avg(salary) as avg_salary,
       sum(bonus) as sum_bonus,
       max(bonus) as max_bonus
from EMP
group by department having sum_bonus >= 50000"""

# Execute SQL
spark.sql(sql_str).show()

Yields the same output as above.

7. Complete Example of Groupby Agg (Aggregate)

Following is a complete example of the groupBy() and agg(). This example is also available at GitHub PySpark Examples project for reference.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.printSchema()
df.show(truncate=False)

from pyspark.sql.functions import sum,avg,max,count,col

# Example 1 - PySpark groupby agg
df.groupBy("department") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)

# Example 2 - groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)
    
# Example 3 - Multiple Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

# Example 4 - Using where on Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
      avg("salary").alias("avg_salary"), \
      sum("bonus").alias("sum_bonus"), \
      max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)

# Example 5 - SQL group by agg
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str = """select department, sum(salary) as sum_salary,
       avg(salary) as avg_salary,
       sum(bonus) as sum_bonus,
       max(bonus) as max_bonus
from EMP
group by department having sum_bonus >= 50000"""
spark.sql(sql_str).show()

8. Conclusion

PySpark DataFrame.groupBy().agg() is used to get the aggregate values like count, sum, avg, min, max for each group. You can also get aggregates per group by using PySpark SQL, in order to use SQL, first you need to create a temporary view.
