PySpark Groupby on Multiple Columns

In PySpark, you can group by multiple columns in two ways: pass a Python list of the DataFrame column names to group on, or pass several column names as separate arguments to the groupBy() method.


In this article, I will explain how to group by multiple columns with the DataFrame API and with PySpark SQL, and how to aggregate each group with functions such as count(), sum(), min(), max(), and avg().

1. Prepare DataFrame

Let’s create a PySpark DataFrame.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)

Running it prints all nine employee rows with the columns employee_name, department, state, salary, age, and bonus.

2. PySpark Groupby on Multiple Columns

Grouping on multiple columns in PySpark is done by passing two or more columns to the groupBy() method. This returns a pyspark.sql.GroupedData object, which provides agg(), sum(), count(), min(), max(), avg(), and similar methods to perform aggregations.

When you group by multiple columns, rows with identical keys (the combination of values in those columns) are redistributed and consolidated. Because this requires shuffling data across the network, groupBy is a wide transformation and can be expensive, so avoid unnecessary group by operations where you can.


# groupby multiple columns & count
# The trailing backslash continues the chained call onto the next line
df.groupBy("department","state").count() \
    .show(truncate=False)

This example groups on the department and state columns and then calls count() to get the number of records in each group; show() prints the result to the console. The nine rows fall into six (department, state) groups: Sales/NY, Finance/CA, and Finance/NY each contain two employees, and Sales/CA, Marketing/CA, and Marketing/NY each contain one.

3. Groupby Multiple Columns from list

In PySpark, you can also pass a Python list of column names to the DataFrame.groupBy() method to group records by the combined values of the listed columns. This is handy when the grouping columns are stored in a variable and reused.

In the example below, group_cols is a list holding the department and state column names, and it is passed as the single argument to groupBy().


# groupby multiple columns from list
group_cols = ["department", "state"]
df.groupBy(group_cols).count() \
    .show(truncate=False)

Yields the same output as above.

4. Using Agg

Grouping on multiple columns isn’t complete without performing several aggregations at once with DataFrame.groupBy().agg().


# Using Multiple Aggregates
# Import the functions module under an alias so sum() and max() don't shadow Python's built-ins
from pyspark.sql import functions as F

group_cols = ["department", "state"]
df.groupBy(group_cols) \
    .agg(F.sum("salary").alias("sum_salary"),
         F.avg("salary").alias("avg_salary"),
         F.sum("bonus").alias("sum_bonus"),
         F.max("bonus").alias("max_bonus")) \
    .show(truncate=False)

I will leave this to you to run and explore the result.

5. PySpark SQL GROUP BY Multiple Columns

Lastly, let’s translate the above code into a PySpark SQL query and run it. First, register the DataFrame as a temporary view with createOrReplaceTempView(); then run the query with SparkSession.sql(). The temporary view remains available until the SparkSession that created it is stopped.


# PySpark SQL Group By AVG, SUM, MAX
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
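sql_str = """
    select department, state,
           sum(salary) as sum_salary,
           avg(salary) as avg_salary,
           sum(bonus) as sum_bonus,
           max(bonus) as max_bonus
    from EMP
    group by department, state
"""

# Execute SQL
spark.sql(sql_str).show(truncate=False)

Because the query groups by both department and state (every non-aggregated column in the select list must appear in the GROUP BY clause), it returns the same result as the agg() example in the previous section.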

6. Complete Example

Following is a complete example for your reference.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
        .master("local[5]").getOrCreate()

# Create DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
df.show(truncate=False)

# groupby multiple columns & count
df.groupBy("department","state").count() \
    .show(truncate=False)

# groupby multiple columns from list
group_cols = ["department", "state"]
df.groupBy(group_cols).count() \
    .show(truncate=False)

# PySpark SQL Group By Count
# Create Temporary table in PySpark
df.createOrReplaceTempView("EMP")

# PySpark SQL
sql_str="select department, state, count(*) as count from EMP " + \
"group by department, state"

# Execute SQL and show the result
spark.sql(sql_str).show()

7. Conclusion

In this article, you have learned how to group a PySpark DataFrame by multiple columns, passing the column names either individually or as a list, and how to do the same with the SQL GROUP BY clause. When you group by multiple columns, rows that share the same key (the combination of values in those columns) are shuffled and brought together. groupBy() returns a pyspark.sql.GroupedData object, which provides agg(), sum(), count(), min(), max(), avg(), and similar methods to perform aggregations.