PySpark Groupby on Multiple Columns

PySpark groupby on multiple columns can be performed either by passing a list of the DataFrame column names you want to group on, or by passing multiple column names as separate parameters to the PySpark groupBy() method.

In this article, I will explain how to perform groupby on multiple columns, including with PySpark SQL, and how to use the sum(), min(), max(), and avg() aggregate functions.

1. Quick Examples of Groupby on Multiple Columns

The following are quick examples of how to group by multiple columns.


# Quick Examples of PySpark Groupby Multiple Columns

# Example 1: groupby multiple columns & count
df.groupBy("department","state").count() \
    .show(truncate=False)

# Example 2: groupby multiple columns from list
group_cols = ["department", "state"]
df.groupBy(group_cols).count() \
    .show(truncate=False)

# Example 3: Using Multiple Aggregates
from pyspark.sql.functions import sum,avg,max
group_cols = ["department", "state"]
df.groupBy(group_cols) \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

# Example 4: PySpark SQL Group By
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, state, count(*) as count from EMP " + \
"group by department, state"

# Execute SQL and show the result
spark.sql(sql_str).show()

Let’s create a PySpark DataFrame.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)

This displays the DataFrame contents in the console.


2. PySpark Groupby on Multiple Columns

Grouping on multiple columns in PySpark can be performed by passing two or more columns to the groupBy() method. This returns a pyspark.sql.GroupedData object, which provides agg(), sum(), count(), min(), max(), avg(), etc. to perform aggregations.
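
For instance, the GroupedData shortcut methods can aggregate several numeric columns at once without calling agg(). A minimal sketch using sum() on the df created above:


# GroupedData shortcut: sum() over multiple numeric columns
df.groupBy("department","state").sum("salary","bonus") \
    .show(truncate=False)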

When you perform a group by on multiple columns, the rows having the same key (the combination of the grouping column values) are shuffled and brought together. Since it involves shuffling data across the network, group by is considered a wide transformation; it is an expensive operation, so avoid it when you can.


# groupby multiple columns & count
df.groupBy("department","state").count()
    .show(truncate=False)

This example performs grouping on the department and state columns and, on the result, uses the count() method to get the number of records for each group. show() is a PySpark method that displays the results in the console.

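As noted above, this grouping triggers a shuffle. If you want to see it, you can inspect the physical plan, where the Exchange operator represents the data movement across partitions. A quick sketch (plan details vary by Spark version):


# Inspect the physical plan; the Exchange node indicates the shuffle
df.groupBy("department","state").count().explain()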

3. Groupby Multiple Columns from list

In PySpark, you can also pass a Python list of column names to the DataFrame.groupBy() method to group records by the values of those columns. A list is used to store multiple items in a single variable.

In the example below, group_cols is a list variable holding the department and state column names; this list is passed as an argument to the groupBy() method.


# groupby multiple columns from list
group_cols = ["department", "state"]
df.groupBy(group_cols).count() \
    .show(truncate=False)

Yields the same output as above.
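
The list can also hold Column objects instead of plain strings. A minimal sketch that is equivalent to the example above:


# groupby multiple columns using a list of Column objects
from pyspark.sql.functions import col
group_cols = [col("department"), col("state")]
df.groupBy(group_cols).count() \
    .show(truncate=False)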

4. Using Agg

No discussion of grouping on multiple columns is complete without covering how to perform multiple aggregations at a time using DataFrame.groupBy().agg().


# Using Multiple Aggregates
from pyspark.sql.functions import sum,avg,max
group_cols = ["department", "state"]
df.groupBy(group_cols) \
    .agg(sum("salary").alias("sum_salary"), \
         avg("salary").alias("avg_salary"), \
         sum("bonus").alias("sum_bonus"), \
         max("bonus").alias("max_bonus") \
     ) \
    .show(truncate=False)

I will leave this to you to run and explore the result.
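
If you also need to filter on the aggregated values (the DataFrame counterpart of SQL HAVING), you can chain where() or filter() after agg(). A minimal sketch reusing the aliases defined above:


# Filter on an aggregated column (similar to SQL HAVING)
from pyspark.sql.functions import col, sum, avg, max
df.groupBy("department","state") \
    .agg(sum("salary").alias("sum_salary"),
         avg("salary").alias("avg_salary"),
         sum("bonus").alias("sum_bonus"),
         max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000) \
    .show(truncate=False)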

5. PySpark SQL GROUP BY Multiple Columns

Finally, let’s convert the above code into a PySpark SQL query and execute it. To do so, first create a temporary view using createOrReplaceTempView(), then run the query with SparkSession.sql(). The view remains available until you end your SparkSession.


# PySpark SQL Group By AVG, SUM, MAX
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str = "select department, state, sum(salary) as sum_salary, " \
    "avg(salary) as avg_salary, " \
    "sum(bonus) as sum_bonus, " \
    "max(bonus) as max_bonus " \
    "from EMP " \
    "group by department, state " \
    "having sum(bonus) >= 50000"

# Execute SQL
spark.sql(sql_str).show()

This yields the same aggregates as the agg() example above, limited by the HAVING clause to groups whose total bonus is at least 50,000.
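
As a side note, the same query reads more easily as a triple-quoted multi-line string, which avoids the manual concatenation. A minimal sketch:


# Same GROUP BY query as a multi-line string
sql_str = """
    select department, state,
           sum(salary) as sum_salary,
           avg(salary) as avg_salary,
           sum(bonus) as sum_bonus,
           max(bonus) as max_bonus
    from EMP
    group by department, state
    having sum(bonus) >= 50000
"""
spark.sql(sql_str).show()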

6. Complete Example

Following is a complete example of groupby on multiple columns.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)

# groupby multiple columns & count
df.groupBy("department","state").count() \
    .show(truncate=False)

# groupby multiple columns from list
group_cols = ["department", "state"]
df.groupBy(group_cols).count() \
    .show(truncate=False)

# PySpark SQL Group By Count
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, state, count(*) as count from EMP " + \
"group by department, state"

# Execute SQL and show the result
spark.sql(sql_str).show()

7. Conclusion

In this article, you have learned how to perform a PySpark groupby on multiple columns of a DataFrame (including from a list) and how to do the same using the SQL GROUP BY clause. When you group by multiple columns, the rows having the same key (the combination of the grouping column values) are shuffled and brought together. Also, groupBy() returns a pyspark.sql.GroupedData object, which provides agg(), sum(), count(), min(), max(), avg(), etc. to perform aggregations.
