PySpark Groupby Agg (aggregate) – Explained

PySpark Groupby Agg is used to calculate more than one aggregate (multiple aggregates) at a time on a grouped DataFrame. To perform the agg, first call groupBy() on the DataFrame, which groups the records based on single or multiple column values, and then call agg() to get the aggregate for each group.

In this article, I will explain how to use the agg() function on a grouped DataFrame with examples. The PySpark groupBy() function is used to collect identical data into groups, and the agg() function performs count, sum, avg, min, max, etc. aggregations on the grouped data.

1. Quick Examples of Groupby Agg

Following are quick examples of how to perform groupBy() and agg() (aggregate).


# Quick Examples

# Imports
from pyspark.sql.functions import sum,avg,max,count

# Example 1 - PySpark groupby agg
df.groupBy("department") \
    .agg(count("*").alias("count")) \
    .show(truncate=False)

# Example 2 - groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")) \
    .show(truncate=False)

# Example 3 - Multiple Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") ) \
    .show(truncate=False)

# Example 4 - SQL group by agg
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, sum(salary) as sum_salary," \
"avg(salary) as avg_salary," \
"sum(bonus) as sum_bonus," \
"max(bonus) as max_bonus" \
" from EMP "  \
" group by department having sum_bonus >= 50000"
spark.sql(sql_str).show()

Before we start running these examples, let’s create a DataFrame from a sequence of data to work with. This DataFrame contains the columns “employee_name”, “department”, “state“, “salary”, “age”, and “bonus”.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)

Yields below output.


2. PySpark Groupby Aggregate Example

By using DataFrame.groupBy().agg() in PySpark, you can get the number of rows for each group using the count aggregate function. DataFrame.groupBy() returns a pyspark.sql.GroupedData object, which provides an agg() method to perform aggregations on the grouped DataFrame.

After performing the aggregations, agg() returns a PySpark DataFrame.


# Syntax
GroupedData.agg(*exprs)

To use aggregate functions like sum(), avg(), min(), max(), etc., you have to import them from pyspark.sql.functions. In the below example, I calculate the number of rows for each group by grouping on the department column and using agg() with the count() function.


# PySpark groupBy() agg
from pyspark.sql.functions import count

df.groupBy("department") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)

Yields below output.


3. Groupby Agg on Multiple Columns

Groupby aggregate on multiple columns in PySpark can be performed by passing two or more columns to the groupBy() function and then using agg(). The following example groups on the department and state columns and, on the result, uses the count() function within agg().


# groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")) \
    .show(truncate=False)

Yields below output.


4. Running Multiple Aggregates at a time

Using groupBy() and agg(), we can calculate multiple aggregates at a time in a single statement using the PySpark SQL aggregate functions sum(), avg(), min(), max(), mean(), count(), etc. In order to use these, import them with "from pyspark.sql.functions import sum,avg,max,min,mean,count".


df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

This example groups on the department column, calculates the sum() and avg() of salary for each department, and calculates the sum() and max() of bonus for each department.


5. Using Where on Aggregate DataFrame

Similar to the SQL “HAVING” clause, on a PySpark DataFrame we can use either the where() or filter() function to filter the rows on top of the aggregated data. Note that the col() function used below also needs to be imported from pyspark.sql.functions.


# Using groupBy(), agg() and where()
from pyspark.sql.functions import col

df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
      avg("salary").alias("avg_salary"), \
      sum("bonus").alias("sum_bonus"), \
      max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)

Yields below output.
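
Since filter() is an alias of where() in PySpark, the same result can also be produced with filter(). The sketch below is a minimal equivalent of the example above, assuming the same df used throughout this article.

# Using filter() instead of where() (equivalent)
from pyspark.sql.functions import sum, avg, max, col

df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"),
         avg("salary").alias("avg_salary"),
         sum("bonus").alias("sum_bonus"),
         max("bonus").alias("max_bonus")) \
    .filter(col("sum_bonus") >= 50000) \
    .show(truncate=False)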

6. PySpark SQL GROUP BY & HAVING

Finally, let’s convert the above groupBy() agg() example into a PySpark SQL query and execute it. To do so, first create a temporary view using createOrReplaceTempView() and then run the query with SparkSession.sql(). The temporary view is available until you end your SparkSession.


# PySpark SQL Group By AVG, SUM, MAX
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, sum(salary) as sum_salary," \
"avg(salary) as avg_salary," \
"sum(bonus) as sum_bonus," \
"max(bonus) as max_bonus" \
" from EMP "  \
" group by department having sum_bonus >= 50000"

# Execute SQL
spark.sql(sql_str).show()

Yields the same output as above.
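
If you want to release the temporary view before ending the SparkSession, you can drop it explicitly with the Catalog API, as sketched below.

# Drop the temporary view once it is no longer needed
spark.catalog.dropTempView("EMP")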

7. Complete Example of Groupby Agg (Aggregate)

Following is a complete example of the groupBy() and agg(). This example is also available at GitHub PySpark Examples project for reference.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.printSchema()
df.show(truncate=False)

from pyspark.sql.functions import sum,avg,max,count,col

# Example 1 - PySpark groupby agg
df.groupBy("department") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)

# Example 2 - groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)
    
# Example 3 - Multiple Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

# Example 4 - Using where on Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
      avg("salary").alias("avg_salary"), \
      sum("bonus").alias("sum_bonus"), \
      max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)

# Example 5 - SQL group by agg
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, sum(salary) as sum_salary," \
"avg(salary) as avg_salary," \
"sum(bonus) as sum_bonus," \
"max(bonus) as max_bonus" \
" from EMP "  \
" group by department having sum_bonus >= 50000"
spark.sql(sql_str).show()

8. Frequently Asked Questions on GroupBy aggregate operations

What is PySpark’s groupBy and aggregate operation used for?

PySpark’s groupBy and aggregate operations are used to perform data aggregation and summarization on a DataFrame. They allow you to group data based on one or more columns and then apply various aggregate functions to compute statistics or transformations on the grouped data.

How do I perform a groupBy operation in PySpark?

You can perform a simple groupBy operation by calling the groupBy method on a DataFrame and specifying the column(s) by which you want to group the data. For example: df.groupBy("Column1").
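
For instance, a minimal sketch using the df from this article shows that groupBy() by itself only returns a GroupedData object, and an aggregation is needed to get a DataFrame back.

# groupBy() returns a GroupedData object; apply an aggregation to get a DataFrame
grouped = df.groupBy("department")
print(type(grouped))    # <class 'pyspark.sql.group.GroupedData'>
grouped.count().show()  # number of rows per department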

What are some common aggregate functions in PySpark?

PySpark provides a wide range of aggregation functions, including sum, avg, max, min, count, collect_list, collect_set, and many more. These functions allow you to compute various statistics or create new columns based on the grouped data.
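
As a quick illustration of the non-numeric aggregates, the sketch below collects values per group using the df from this article; collect_list() keeps duplicates while collect_set() removes them.

# Collect values per group
from pyspark.sql.functions import collect_list, collect_set

df.groupBy("department") \
    .agg(collect_list("employee_name").alias("employees"),
         collect_set("state").alias("states")) \
    .show(truncate=False)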

How do I use the agg method in PySpark?

In order to use the agg() function, the data first needs to be grouped by applying the groupBy() method. You can then pass a dictionary where the keys are the columns you want to aggregate and the values are the names of the aggregation functions. For example: grouped.agg({"Value": "sum"}).
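
For example, here is a minimal sketch using the df from this article; note that the dictionary form does not let you alias the result columns directly, so the generated names look like sum(salary) and can be renamed afterwards if needed.

# Dictionary form: keys are column names, values are aggregate function names
df.groupBy("department") \
    .agg({"salary": "sum", "bonus": "max"}) \
    .withColumnRenamed("sum(salary)", "sum_salary") \
    .withColumnRenamed("max(bonus)", "max_bonus") \
    .show(truncate=False)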

9. Conclusion

PySpark DataFrame.groupBy().agg() is used to get aggregate values like count, sum, avg, min, and max for each group. You can also get aggregates per group by using PySpark SQL; in order to use SQL, you first need to create a temporary view.

