PySpark collect_list() and collect_set() functions

  • Post category: PySpark
  • Post last modified: December 18, 2022

The PySpark SQL functions collect_list() and collect_set() are aggregate functions that create an array (ArrayType) column on a DataFrame by merging rows, typically after a group by or over a window partition. In this article, I will explain how to use these two functions and walk through their differences with examples.

To explain these functions with examples, let’s first create a PySpark DataFrame.


# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
                    .getOrCreate()

# Prepare data
data = [('James','Java'),
  ('James','Python'),
  ('James','Python'),
  ('Anna','PHP'),
  ('Anna','Javascript'),
  ('Maria','Java'),
  ('Maria','C++'),
  ('James','Scala'),
  ('Anna','PHP'),
  ('Anna','HTML')
]

# Create DataFrame
df = spark.createDataFrame(data,schema=["name","languages"])
df.printSchema()
df.show()

This yields the output below.

[Image: pyspark collect_list]

1. PySpark collect_list() Syntax & Usage

The PySpark function collect_list() aggregates the values of a column into an ArrayType column, typically after a group by or over a window partition. It keeps every value, including duplicates.

1.1 collect_list() Syntax

Following is the syntax of collect_list(). The col argument is the column, or column name, whose values are collected.


# Syntax of collect_list()
pyspark.sql.functions.collect_list(col)

1.2 collect_list() Examples

In our example, the DataFrame has two columns, name and languages. James appears with 4 languages (Python is duplicated) and Anna appears with 4 languages (PHP is duplicated). Suppose you want to group by name and collect all the values of languages into an array. You achieve this by first grouping on “name” and then aggregating on “languages”.

Note that collect_list() collects all values, including duplicates. The order of the elements in the array depends on the order of the rows, which is not guaranteed after a shuffle.


# Using collect_list()
from pyspark.sql.functions import collect_list
df2 = df.groupBy("name").agg(collect_list("languages") \
    .alias("languages"))
df2.printSchema()    
df2.show(truncate=False)

This yields the output below.

[Image: pyspark collect_list example]

2. PySpark collect_set() Syntax & Usage

The PySpark SQL function collect_set() is similar to collect_list(). The difference is that collect_set() eliminates duplicates, so each distinct value appears only once in the resulting array.
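To make the difference concrete, here is a plain-Python model of what the two aggregations return for each group. It only illustrates the semantics, not how Spark implements them, and the helper names model_collect_list and model_collect_set are hypothetical. It also mirrors the fact that both Spark functions skip null values.

```python
from collections import defaultdict

# Same (name, language) rows as the DataFrame in this article
data = [('James', 'Java'), ('James', 'Python'), ('James', 'Python'),
        ('Anna', 'PHP'), ('Anna', 'Javascript'), ('Maria', 'Java'),
        ('Maria', 'C++'), ('James', 'Scala'), ('Anna', 'PHP'),
        ('Anna', 'HTML')]


def model_collect_list(rows):
    """Hypothetical helper: group by key and keep every non-null value, duplicates included."""
    groups = defaultdict(list)
    for key, value in rows:
        values = groups[key]  # create the group even if all its values are null
        if value is not None:  # like Spark, skip nulls
            values.append(value)
    return dict(groups)


def model_collect_set(rows):
    """Hypothetical helper: group by key and keep each distinct non-null value once."""
    # Spark does not guarantee element order, so sort here for a stable result
    return {key: sorted(set(values))
            for key, values in model_collect_list(rows).items()}


print(model_collect_list(data)['James'])  # ['Java', 'Python', 'Python', 'Scala']
print(model_collect_set(data)['James'])   # ['Java', 'Python', 'Scala']
```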

2.1 collect_set() Syntax

Following is the syntax of collect_set(). It takes the same col argument as collect_list().


# Syntax of collect_set()
pyspark.sql.functions.collect_set(col)

2.2 collect_set() Example


# Using collect_set()
from pyspark.sql.functions import collect_set
df2 = df.groupBy("name").agg(collect_set("languages") \
    .alias("languages"))
df2.printSchema()    
df2.show(truncate=False)

This yields the output below.

[Image: pyspark collect_set]

3. Complete Example

Following is the complete example of PySpark collect_list() vs collect_set().


# Imports
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list, collect_set

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
                    .getOrCreate()

# Prepare data
data = [('James','Java'),
  ('James','Python'),
  ('James','Python'),
  ('Anna','PHP'),
  ('Anna','Javascript'),
  ('Maria','Java'),
  ('Maria','C++'),
  ('James','Scala'),
  ('Anna','PHP'),
  ('Anna','HTML')
]

# Create DataFrame
df = spark.createDataFrame(data,schema=["name","languages"])
df.printSchema()
df.show()

# Using collect_list(): keeps duplicates
df2 = df.groupBy("name").agg(collect_list("languages") \
    .alias("languages"))
df2.printSchema()
df2.show(truncate=False)

# Using collect_set(): removes duplicates
df3 = df.groupBy("name").agg(collect_set("languages") \
    .alias("languages"))
df3.printSchema()
df3.show(truncate=False)

4. Conclusion

In summary, the PySpark SQL functions collect_list() and collect_set() aggregate data into an array and return an ArrayType column. collect_set() removes duplicates and returns only unique values, whereas collect_list() returns all values as they are, without eliminating duplicates.
