PySpark groupBy().agg() is used to calculate more than one aggregate (multiple aggregates) at a time on a grouped DataFrame. To use agg(), first apply groupBy() to the DataFrame, which organizes the records based on single or multiple column values. Then call agg() on the result of groupBy() to obtain the aggregate values for each group.
In this article, I will explain agg() function on grouped DataFrame with examples. In PySpark, the groupBy() function gathers similar data into groups, while the agg() function is then utilized to execute various aggregations such as count, sum, average, minimum, maximum, and others on the grouped data.
Before proceeding with these examples, let’s generate the DataFrame from a sequence of data. This DataFrame includes the columns “employee_name”, “department”, “state”, “salary”, “age”, and “bonus”.
# Import
from pyspark.sql import SparkSession
# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
    .master("local[5]").getOrCreate()
# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
("Michael","Sales","NY",86000,56,20000),
("Robert","Sales","CA",81000,30,23000),
("Maria","Finance","CA",90000,24,23000),
("Raman","Finance","CA",99000,40,24000),
("Scott","Finance","NY",83000,36,19000),
("Jen","Finance","NY",79000,53,15000),
("Jeff","Marketing","CA",80000,25,18000),
("Kumar","Marketing","NY",91000,50,21000)
]
schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)
Yields below output.
1. PySpark Groupby Aggregate Example
Use DataFrame.groupBy().agg() in PySpark to calculate the total number of rows for each group by specifying the count aggregate function. DataFrame.groupBy() returns a pyspark.sql.GroupedData object, and agg() is a method of the GroupedData class. After performing the aggregations, it returns a PySpark DataFrame.
# Syntax
GroupedData.agg(*exprs)
To use aggregate functions like sum(), avg(), min(), max(), etc., you have to import them from pyspark.sql.functions. In the below example, I am calculating the number of rows for each group by grouping on the department column and using the agg() and count() functions.
# PySpark groupBy() agg
from pyspark.sql.functions import count
df.groupBy("department") \
.agg(count("*").alias("count")
) \
.show(truncate=False)
Yields below output.
2. Groupby Agg on Two or More Columns
In PySpark, conducting Groupby Aggregate on Multiple Columns involves supplying two or more columns to the groupBy() and utilizing agg(). In the subsequent example, grouping is executed based on the “department” and “state” columns, and within agg(), the count() function is used.
from pyspark.sql.functions import count
# groupby multiple columns & agg
df.groupBy("department","state") \
.agg(count("*").alias("count")) \
.show(truncate=False)
Output:
3. Running Multiple Aggregates at a time
You can execute multiple aggregates simultaneously within a single statement by using the PySpark SQL aggregate functions sum(), avg(), min(), max(), mean(), count(), etc. To use these functions, import them with “from pyspark.sql.functions import sum, avg, max, min, mean, count”.
from pyspark.sql.functions import sum,avg,max
df.groupBy("department") \
.agg(sum("salary").alias("sum_salary"), \
avg("salary").alias("avg_salary"), \
sum("bonus").alias("sum_bonus"), \
max("bonus").alias("max_bonus") \
) \
.show(truncate=False)
In this example, grouping is performed based on the “department” column. It calculates the sum() and avg() of the “salary” for each department and the sum() and max() of the bonus for each department.
4. Using Where on Aggregate DataFrame
Use the where() function to filter rows based on aggregated data, similar to SQL’s “HAVING” clause. This functionality is essential for selecting or excluding data based on specified conditions.
# Using groupBy(), agg() and where()
from pyspark.sql.functions import sum, avg, max, col
df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)
Yields below output.
5. PySpark SQL GROUP BY & HAVING
Lastly, let’s use the PySpark SQL query and execute the above example. As a first step, you need to create a temporary view using createOrReplaceTempView(), then use the SparkSession.sql() to execute the SQL query.
# PySpark SQL Group By AVG, SUM, MAX
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")
# PySpark SQL
sql_str="select department, sum(salary) as sum_salary," \
"avg(salary) as avg_salary," \
"sum(bonus) as sum_bonus," \
"max(bonus) as max_bonus" \
" from EMP " \
" group by department having sum_bonus >= 50000"
# Execute SQL
spark.sql(sql_str).show()
This yields the same output as above.
6. Complete Example of Groupby Agg (Aggregate)
This example is also available at GitHub PySpark Examples project for reference.
# Import
from pyspark.sql import SparkSession
# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
.master("local[5]").getOrCreate()
# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
("Michael","Sales","NY",86000,56,20000),
("Robert","Sales","CA",81000,30,23000),
("Maria","Finance","CA",90000,24,23000),
("Raman","Finance","CA",99000,40,24000),
("Scott","Finance","NY",83000,36,19000),
("Jen","Finance","NY",79000,53,15000),
("Jeff","Marketing","CA",80000,25,18000),
("Kumar","Marketing","NY",91000,50,21000)
]
schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.printSchema()
df.show(truncate=False)
from pyspark.sql.functions import sum, avg, max, count, col
# Example 1 - PySpark groupby agg
df.groupBy("department") \
.agg(count("*").alias("count")
) \
.show(truncate=False)
# Example 2 - groupby multiple columns & agg
df.groupBy("department","state") \
.agg(count("*").alias("count")
) \
.show(truncate=False)
# Example 3 - Multiple Aggregates
df.groupBy("department") \
.agg(sum("salary").alias("sum_salary"), \
avg("salary").alias("avg_salary"), \
sum("bonus").alias("sum_bonus"), \
max("bonus").alias("max_bonus") \
) \
.show(truncate=False)
# Example 4 - Using where on Aggregates
df.groupBy("department") \
.agg(sum("salary").alias("sum_salary"), \
avg("salary").alias("avg_salary"), \
sum("bonus").alias("sum_bonus"), \
max("bonus").alias("max_bonus")) \
.where(col("sum_bonus") >= 50000) \
.show(truncate=False)
# Example 5 - SQL group by agg
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")
# PySpark SQL
sql_str="select department, sum(salary) as sum_salary," \
"avg(salary) as avg_salary," \
"sum(bonus) as sum_bonus," \
"max(bonus) as max_bonus" \
" from EMP " \
" group by department having sum_bonus >= 50000"
spark.sql(sql_str).show()
7. Frequently Asked Questions on GroupBy aggregate operations
PySpark’s groupBy and aggregate operations are used to perform data aggregation and summarization on a DataFrame. They allow you to group data based on one or more columns and then apply various aggregate functions to compute statistics or transformations on the grouped data.
You can perform a simple groupBy operation by calling the groupBy() method on a DataFrame and specifying the column(s) by which you want to group the data, for example: df.groupBy("Column1").
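Using the DataFrame created earlier in this article, a minimal group-by count per department might look like the snippet below, which uses the count() shortcut on the grouped data instead of agg().

# Simple groupBy with the count() shortcut on GroupedData
df.groupBy("department").count().show(truncate=False)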
PySpark provides a wide range of aggregation functions, including sum, avg, max, min, count, collect_list, collect_set, and many more. These functions allow you to compute various statistics or create new columns based on the grouped data.
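As a quick sketch using the same DataFrame, collect_list() and collect_set() can be used inside agg() to gather the grouped values into arrays, for example the employee names and distinct states per department:

# collect_list() keeps duplicates, collect_set() keeps only distinct values
from pyspark.sql.functions import collect_list, collect_set
df.groupBy("department") \
    .agg(collect_list("employee_name").alias("employees"), \
         collect_set("state").alias("states")) \
    .show(truncate=False)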
How do you use the agg method in PySpark? To use the agg function, the data first needs to be grouped by applying the groupBy() method. You can then pass a dictionary where the keys are the columns you want to aggregate and the values are the aggregation functions, for example: grouped.agg({"Value": "sum"}).
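Applied to the DataFrame used throughout this article, the dictionary form might look like the sketch below; note that the result columns get default names such as sum(salary) unless you rename them afterwards.

# Dictionary-style agg: keys are column names, values are aggregate function names
df.groupBy("department").agg({"salary": "sum", "bonus": "max"}).show(truncate=False)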
8. Conclusion
PySpark DataFrame.groupBy().agg() is used to get aggregate values like count, sum, avg, min, and max for each group. You can also get aggregates per group by using PySpark SQL; to use SQL, you first need to create a temporary view.
Related Articles
- PySpark Column alias after groupBy() Example
- PySpark DataFrame groupBy and Sort by Descending Order
- PySpark Count of Non null, nan Values in DataFrame
- PySpark Count Distinct from DataFrame
- PySpark sum() Columns Example
- PySpark – Find Count of null, None, NaN Values
- PySpark Groupby Count Distinct
- PySpark Groupby on Multiple Columns
- PySpark GroupBy Count – Explained