PySpark Groupby Agg (aggregate) – Explained

PySpark Groupby Agg is used to calculate more than one aggregate (multiple aggregates) at a time on grouped DataFrame. To utilize agg, first, apply the groupBy() to the DataFrame, which organizes the records based on single or multiple-column values. Subsequently, use agg() on the result of groupBy() to obtain the aggregate values for each group.


In this article, I will explain agg() function on grouped DataFrame with examples. In PySpark, the groupBy() function gathers similar data into groups, while the agg() function is then utilized to execute various aggregations such as count, sum, average, minimum, maximum, and others on the grouped data.

Before proceeding with these examples, let’s create a DataFrame from a sequence of data. This DataFrame includes the columns “employee_name”, “department”, “state”, “salary”, “age”, and “bonus”.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)

Yields below output.


1. PySpark Groupby Aggregate Example

Use DataFrame.groupBy().agg() in PySpark to calculate the total number of rows for each group by specifying the aggregate function count. The DataFrame.groupBy() function returns a pyspark.sql.GroupedData object, and agg() is a method of the GroupedData class.

After performing the aggregations, agg() returns a PySpark DataFrame.


# Syntax
GroupedData.agg(*exprs)
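
As a quick optional check (a minimal sketch, not required in normal usage), you can inspect the intermediate object types to see this in action:


# groupBy() returns a GroupedData object; agg() turns it back into a DataFrame
grouped = df.groupBy("department")
print(type(grouped))   # <class 'pyspark.sql.group.GroupedData'>
result = grouped.agg({"salary": "max"})
print(type(result))    # <class 'pyspark.sql.dataframe.DataFrame'>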

To use aggregate functions such as sum(), avg(), min(), max(), etc., you have to import them from pyspark.sql.functions. In the example below, I calculate the number of rows for each group by grouping on the department column and using agg() with the count() function.


# PySpark groupBy() agg
from pyspark.sql.functions import count

df.groupBy("department") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)

Yields below output.


2. Groupby Agg on Two or More Columns

In PySpark, conducting Groupby Aggregate on Multiple Columns involves supplying two or more columns to the groupBy() and utilizing agg(). In the subsequent example, grouping is executed based on the “department” and “state” columns, and within agg(), the count() function is used.


from pyspark.sql.functions import count

# groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")) \
    .show(truncate=False)

Output:


3. Running Multiple Aggregates at a time

Executing multiple aggregates simultaneously within a single statement can be achieved by using the PySpark SQL aggregate functions sum(), avg(), min(), max(), mean(), count(), etc. To use these functions, import them with “from pyspark.sql.functions import sum, avg, max, min, mean, count”.


from pyspark.sql.functions import sum,avg,max

df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

In this example, grouping is performed based on the “department” column. It calculates the sum() and avg() of the “salary” for each department and the sum() and max() of the bonus for each department.
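
Note that importing sum, min, and max from pyspark.sql.functions shadows Python’s built-in functions with the same names. A common alternative, sketched below, is to import the module under an alias and call the functions through it:


# Same aggregation as above, using the F alias to avoid shadowing built-ins
import pyspark.sql.functions as F

df.groupBy("department") \
    .agg(F.sum("salary").alias("sum_salary"),
         F.avg("salary").alias("avg_salary"),
         F.sum("bonus").alias("sum_bonus"),
         F.max("bonus").alias("max_bonus")) \
    .show(truncate=False)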


4. Using Where on Aggregate DataFrame

Use the where() function to filter rows based on aggregated data, similar to SQL’s “HAVING” clause. This functionality is essential for selecting or excluding data based on specified conditions.


# Using groupBy(), agg() and where()
from pyspark.sql.functions import col

df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
      avg("salary").alias("avg_salary"), \
      sum("bonus").alias("sum_bonus"), \
      max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)

Yields below output.

5. PySpark SQL GROUP BY & HAVING

Lastly, let’s run the above example using a PySpark SQL query. As a first step, create a temporary view using createOrReplaceTempView(), then use SparkSession.sql() to execute the SQL query.


# PySpark SQL Group By AVG, SUM, MAX
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, sum(salary) as sum_salary," \
"avg(salary) as avg_salary," \
"sum(bonus) as sum_bonus," \
"max(bonus) as max_bonus" \
" from EMP "  \
" group by department having sum_bonus >= 50000"

# Execute SQL
spark.sql(sql_str).show()

This query yields the same output as the previous example.

6. Complete Example of Groupby Agg (Aggregate)

This example is also available in the GitHub PySpark Examples project for reference.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.printSchema()
df.show(truncate=False)

from pyspark.sql.functions import sum, avg, max, count, col

# Example 1 - PySpark groupby agg
df.groupBy("department") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)

# Example 2 - groupby multiple columns & agg
df.groupBy("department","state") \
    .agg(count("*").alias("count")
     ) \
    .show(truncate=False)
    
# Example 3 - Multiple Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

# Example 4 - Using where on Aggregates
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
      avg("salary").alias("avg_salary"), \
      sum("bonus").alias("sum_bonus"), \
      max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)

# Example 5 - SQL group by agg
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, sum(salary) as sum_salary," \
"avg(salary) as avg_salary," \
"sum(bonus) as sum_bonus," \
"max(bonus) as max_bonus" \
" from EMP "  \
" group by department having sum_bonus >= 50000"
spark.sql(sql_str).show()

7. Frequently Asked Questions on GroupBy aggregate operations

What is PySpark’s groupBy and aggregate operation used for?

PySpark’s groupBy and aggregate operations are used to perform data aggregation and summarization on a DataFrame. They allow you to group data based on one or more columns and then apply various aggregate functions to compute statistics or transformations on the grouped data.

How do I perform a groupBy operation in PySpark?

You can perform a simple groupBy operation by calling the groupBy method on a DataFrame and specifying the column(s) by which you want to group the data. For example: df.groupBy("Column1").

What are some common aggregate functions in PySpark?

PySpark provides a wide range of aggregation functions, including sum, avg, max, min, count, collect_list, collect_set, and many more. These functions allow you to compute various statistics or create new columns based on the grouped data.
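
As a minimal sketch using the sample DataFrame from this article, the example below collects employee names per department with collect_list() (keeps duplicates) and the distinct states with collect_set():


# collect_list keeps duplicates, collect_set keeps only distinct values
from pyspark.sql.functions import collect_list, collect_set

df.groupBy("department") \
    .agg(collect_list("employee_name").alias("employees"),
         collect_set("state").alias("states")) \
    .show(truncate=False)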

How do I use the agg method in PySpark?

To use the agg() method, first group the data by applying the groupBy() method. You can then pass a dictionary where the keys are the columns you want to aggregate and the values are the names of the aggregation functions. For example: grouped.agg({"Value": "sum"}).
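
For instance, a minimal sketch using the DataFrame created earlier (note that the dictionary form does not support alias(), so the result columns are named like sum(salary)):


# Dictionary-style agg: keys are column names, values are aggregate function names
grouped = df.groupBy("department")
grouped.agg({"salary": "sum", "bonus": "max"}).show()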

8. Conclusion

PySpark DataFrame.groupBy().agg() is used to get the aggregate values like count, sum, avg, min, max for each group. You can also get aggregates per group by using PySpark SQL, in order to use SQL, first you need to create a temporary view.
