PySpark SQL collect_list() and collect_set() functions are used to create an array (ArrayType) column on a DataFrame by merging rows, typically after a group by or a window partition. In this article, I will explain how to use these two functions and cover their differences with examples.
In order to explain these with examples, first let’s create a PySpark DataFrame.
# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
    .getOrCreate()

# Prepare data
data = [('James','Java'),
        ('James','Python'),
        ('James','Python'),
        ('Anna','PHP'),
        ('Anna','Javascript'),
        ('Maria','Java'),
        ('Maria','C++'),
        ('James','Scala'),
        ('Anna','PHP'),
        ('Anna','HTML')
       ]

# Create DataFrame
df = spark.createDataFrame(data, schema=["name","languages"])
df.printSchema()
df.show()
This yields the below output.
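With a local SparkSession, the printed schema and data should look like the following (row order follows the input list):

root
 |-- name: string (nullable = true)
 |-- languages: string (nullable = true)

+-----+----------+
| name| languages|
+-----+----------+
|James|      Java|
|James|    Python|
|James|    Python|
| Anna|       PHP|
| Anna|Javascript|
|Maria|      Java|
|Maria|       C++|
|James|     Scala|
| Anna|       PHP|
| Anna|      HTML|
+-----+----------+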
1. PySpark collect_list() Syntax & Usage
The PySpark function collect_list() is used to aggregate the values of a column into an ArrayType column, typically after a group by or over a window partition.
1.1 collect_list() Syntax
Following is the syntax of collect_list().
# Syntax of collect_list()
pyspark.sql.functions.collect_list(col)
1.2 collect_list() Examples
In our example, we have the columns name and languages. If you look at the data, James likes three languages (with one duplicate entry, Python) and Anna likes three languages (with one duplicate entry, PHP). Now, let's say you want to group by name and collect all values of languages as an array. This is achieved by first grouping on "name" and then aggregating on languages.
Note that collect_list() collects and includes all duplicates.
# Using collect_list()
from pyspark.sql.functions import collect_list
df2 = df.groupBy("name").agg(collect_list("languages") \
    .alias("languages"))
df2.printSchema()
df2.show(truncate=False)
This yields the below output.
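Row order after a groupBy is not guaranteed, and the exact nullability flags in the schema can vary by Spark version, but a typical run prints something like this. Note that the duplicates (Python for James, PHP for Anna) are kept:

root
 |-- name: string (nullable = true)
 |-- languages: array (nullable = false)
 |    |-- element: string (containsNull = false)

+-----+-----------------------------+
|name |languages                    |
+-----+-----------------------------+
|James|[Java, Python, Python, Scala]|
|Anna |[PHP, Javascript, PHP, HTML] |
|Maria|[Java, C++]                  |
+-----+-----------------------------+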
2. PySpark collect_set() Syntax & Usage
PySpark SQL function collect_set() is similar to collect_list(). The difference is that collect_set() eliminates duplicates, so the resulting array contains only unique values.
2.1 collect_set() Syntax
Following is the syntax of collect_set().
# Syntax of collect_set()
pyspark.sql.functions.collect_set(col)
2.2 collect_set() Example
# Using collect_set()
from pyspark.sql.functions import collect_set
df2 = df.groupBy("name").agg(collect_set("languages") \
    .alias("languages"))
df2.printSchema()
df2.show(truncate=False)
This yields the below output.
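Since a set has no defined order, both the row order and the order of elements inside each array can vary from run to run. A typical run prints:

root
 |-- name: string (nullable = true)
 |-- languages: array (nullable = false)
 |    |-- element: string (containsNull = false)

+-----+-----------------------+
|name |languages              |
+-----+-----------------------+
|James|[Java, Python, Scala]  |
|Anna |[PHP, Javascript, HTML]|
|Maria|[Java, C++]            |
+-----+-----------------------+

Notice that the duplicates (Python for James, PHP for Anna) have been removed.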
3. Complete Example
Following is a complete example of PySpark collect_list() vs collect_set().
# Import
from pyspark.sql import SparkSession

# Create SparkSession
spark = SparkSession.builder.appName('SparkByExamples.com') \
    .getOrCreate()

# Prepare data
data = [('James','Java'),
        ('James','Python'),
        ('James','Python'),
        ('Anna','PHP'),
        ('Anna','Javascript'),
        ('Maria','Java'),
        ('Maria','C++'),
        ('James','Scala'),
        ('Anna','PHP'),
        ('Anna','HTML')
       ]

# Create DataFrame
df = spark.createDataFrame(data, schema=["name","languages"])
df.printSchema()
df.show()

# Using collect_list()
from pyspark.sql.functions import collect_list
df2 = df.groupBy("name").agg(collect_list("languages") \
    .alias("languages"))
df2.printSchema()
df2.show(truncate=False)

# Using collect_set()
from pyspark.sql.functions import collect_set
df2 = df.groupBy("name").agg(collect_set("languages") \
    .alias("languages"))
df2.printSchema()
df2.show(truncate=False)
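The examples above aggregate after a groupBy(), but as mentioned earlier, both functions can also be applied over a window partition, which keeps every input row and attaches the collected array to each of them. Below is a minimal sketch of this usage; the column names all_languages and unique_languages are just illustrative:

# Collecting over a window keeps all original rows
from pyspark.sql import Window
from pyspark.sql.functions import collect_list, collect_set

w = Window.partitionBy("name")
df3 = df.withColumn("all_languages", collect_list("languages").over(w)) \
        .withColumn("unique_languages", collect_set("languages").over(w))
df3.show(truncate=False)

Each row now carries the full and the de-duplicated language arrays for its name, instead of the DataFrame being collapsed to one row per name.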
4. Conclusion
In summary, the PySpark SQL functions collect_list() and collect_set() aggregate column values into an array and return an ArrayType column. collect_set() de-dupes the data and returns unique values, whereas collect_list() returns the values as-is, without eliminating duplicates.
Related Articles
- PySpark selectExpr() with Example
- How to Convert PySpark Column to List?
- PySpark Create DataFrame from List
- PySpark Apply Function to Column
- PySpark flatMap() Transformation
- PySpark RDD Transformations with examples
- PySpark between() range of values
- PySpark max() – Different Methods Explained
- PySpark sum() Columns Example
- PySpark union two DataFrames
- PySpark Broadcast Variable
- PySpark Broadcast Join
- PySpark persist() Example