PySpark repartition() – Explained with Examples

The pyspark.sql.DataFrame.repartition() method is used to increase or decrease the number of RDD/DataFrame partitions, either to a given number of partitions or by one or more column names. It takes two parameters, numPartitions and *cols. You must pass at least one of them, and the other is optional. repartition() is a wide transformation that shuffles the data, so it is considered an expensive operation.


Key Points –

  • repartition() is used to increase or decrease the number of partitions.
  • repartition() generally produces more evenly sized partitions than coalesce(), because it reshuffles all rows.
  • It is a wide transformation.
  • It is an expensive operation because it shuffles data and consumes more resources.
  • repartition() takes an int, one or more column names, or both to define how the data is partitioned.
  • If only column names are specified, the number of partitions defaults to the spark.sql.shuffle.partitions setting (200 by default).
  • For performance, avoid calling this function unless you need it; if you only need fewer partitions, coalesce() is usually cheaper.
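To make the contrast with coalesce() concrete, here is a minimal plain-Python model, not Spark code. Each partition is a Python list; repartition is modeled as dealing every row round-robin into the new partitions (a full shuffle), and coalesce is modeled as merging whole neighboring partitions without splitting them. Spark's real coalesce groups parent partitions by data locality, so the grouping rule used here is a simplifying assumption, and the function names are hypothetical.

```python
from typing import List, TypeVar

T = TypeVar("T")


def model_repartition(parts: List[List[T]], n: int) -> List[List[T]]:
    """Model of repartition(n): every row can move (full shuffle).
    Rows are dealt round-robin, so in this model sizes differ by at most one."""
    out: List[List[T]] = [[] for _ in range(n)]
    i = 0
    for part in parts:
        for row in part:
            out[i % n].append(row)
            i += 1
    return out


def model_coalesce(parts: List[List[T]], n: int) -> List[List[T]]:
    """Model of coalesce(n) with n <= len(parts): whole input partitions are
    merged into n groups without a shuffle, so existing skew is kept."""
    out: List[List[T]] = [[] for _ in range(n)]
    for idx, part in enumerate(parts):
        out[idx * n // len(parts)].extend(part)
    return out


# Four skewed input partitions holding the rows 1..12
parts = [[1, 2, 3, 4, 5, 6, 7, 8], [9], [10], [11, 12]]
print([len(p) for p in model_repartition(parts, 2)])  # [6, 6]
print([len(p) for p in model_coalesce(parts, 2)])     # [9, 3]
```

Both models keep every row exactly once; only the shuffle-based model evens out the skew, which is the trade-off the key points above describe.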

1. Quick Examples of PySpark repartition()

Following are quick examples of PySpark repartition() of DataFrame.

# Repartition by number
df2 = df.repartition(5)

# Repartition by column name
df2 = df.repartition("state")

# Repartition by number and column name
df2 = df.repartition(5, "state")

# Repartition by multiple columns
df2 = df.repartition("state", "department")

2. DataFrame.repartition()

repartition() is a method of the pyspark.sql.DataFrame class that is used to increase or decrease the number of partitions of a DataFrame. When you create a DataFrame, its rows are distributed across multiple partitions, which may be spread over many servers. Use this method to redistribute the data into fewer or more partitions.

2.1 Syntax

Following is the syntax of DataFrame.repartition()

# Syntax of repartition()
DataFrame.repartition(numPartitions, *cols)

2.2 Parameters & Return Type

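repartition() takes the following parameters and returns a new DataFrame with the repartitioned data.

  • numPartitions – Target number of partitions, given as an int. If you pass a column name here instead, it is treated as the first partitioning column and the number of partitions defaults to spark.sql.shuffle.partitions.
  • *cols – One or more columns to partition by. Rows with the same values in these columns are placed in the same partition.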
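How the two parameters combine can be summarized with a small plain-Python model of the argument handling. This is a sketch, not PySpark's source: the helper describe_repartition and its return format are hypothetical, and it accepts only column names as strings, while PySpark also accepts Column objects. An int alone means round-robin partitioning into that many partitions; an int plus columns means hash partitioning into that many partitions; columns alone mean hash partitioning into spark.sql.shuffle.partitions partitions (200 by default), which adaptive query execution in Spark 3.x may coalesce further.

```python
from typing import Tuple, Union

DEFAULT_SHUFFLE_PARTITIONS = 200  # default value of spark.sql.shuffle.partitions


def describe_repartition(
    num_partitions: Union[int, str],
    *cols: str,
    shuffle_partitions: int = DEFAULT_SHUFFLE_PARTITIONS,
) -> Tuple[str, int, Tuple[str, ...]]:
    """Hypothetical helper modeling how repartition(numPartitions, *cols)
    interprets its arguments: returns (strategy, partition count, key columns)."""
    if isinstance(num_partitions, int):
        if not cols:
            # Only a number: rows are spread round-robin
            return ("round-robin", num_partitions, ())
        # Number plus columns: hash partitioning into exactly that many partitions
        return ("hash", num_partitions, tuple(cols))
    if isinstance(num_partitions, str):
        # A column name passed first becomes the first partitioning column
        return ("hash", shuffle_partitions, (num_partitions,) + tuple(cols))
    raise TypeError("numPartitions should be an int or a column name")


print(describe_repartition(5))                      # ('round-robin', 5, ())
print(describe_repartition("state"))                # ('hash', 200, ('state',))
print(describe_repartition(5, "state"))             # ('hash', 5, ('state',))
print(describe_repartition("state", "department"))  # ('hash', 200, ('state', 'department'))
```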

3. PySpark DataFrame repartition()

repartition() redistributes the data from all partitions into the specified number of partitions. This requires a full data shuffle, which is very expensive when you have billions or trillions of rows. To see how it works, let’s create a DataFrame with some test data.

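# Imports
from pyspark.sql import SparkSession

# Create Spark Session (local[5] runs Spark locally with 5 threads)
spark = SparkSession.builder.appName('') \
    .master("local[5]") \
    .getOrCreate()

# Create PySpark DataFrame
simpleData = [("James","Sales","NY",90000,34,10000),
    # ... further rows elided in the original example
]

schema = ["employee_name","department","state","salary","age","bonus"]
df = spark.createDataFrame(data=simpleData, schema = schema)
print(df.rdd.getNumPartitions())

# Write to CSV file (example output path); one part file per partition
df.write.mode("overwrite").csv("/tmp/partition.csv")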

By default, a DataFrame is created with a number of partitions that depends on your environment and setup. Writing it out creates multiple part files at the specified location, one per partition, so the number of part files equals the number of partitions. Note that if you run this multiple times, the rows in each part file may differ between runs.


3.1 Repartition by Number

Now let’s repartition this data into 3 partitions by passing 3 to the numPartitions parameter.

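# repartition()
df2 = df.repartition(numPartitions=3)
print(df2.rdd.getNumPartitions())  # 3

# Write DataFrame to CSV file (example output path)
df2.write.mode("overwrite").csv("/tmp/partition.csv")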

It repartitions the DataFrame into 3 partitions.

3.2 Repartition by Column

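Using repartition(), you can also partition a PySpark DataFrame by a single column or by multiple columns. In the following example, repartition() redistributes the data by the state column, so all rows with the same state end up in the same partition. Because no numPartitions is given, the number of partitions comes from spark.sql.shuffle.partitions (200 by default), although adaptive query execution in Spark 3.x may coalesce them afterwards.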

# repartition by column
df2 = df.repartition("state")

# Write (example output path)
df2.write.mode("overwrite").csv("/tmp/partition.csv")
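Partitioning by a column is hash partitioning: each row goes to hash(column value) modulo the number of partitions. The plain-Python model below is a sketch of that rule under stated assumptions: zlib.crc32 stands in for the Murmur3 hash Spark actually uses, so the concrete partition numbers will not match Spark's, and all rows except the first (taken from the article's data) are hypothetical. It shows why repartition("state") with the default 200 partitions leaves most partitions empty when there are only a few distinct states.

```python
import zlib
from typing import Dict, List, Tuple

Row = Dict[str, str]


def hash_partition(rows: List[Row], key_cols: Tuple[str, ...], n: int) -> List[List[Row]]:
    """Model of repartition(n, *key_cols): a row goes to hash(key) mod n.
    zlib.crc32 is a stand-in for Spark's Murmur3 hash."""
    parts: List[List[Row]] = [[] for _ in range(n)]
    for row in rows:
        key = "|".join(row[c] for c in key_cols).encode("utf-8")
        parts[zlib.crc32(key) % n].append(row)
    return parts


# Hypothetical rows; only the first one comes from the article's data
rows = [
    {"employee_name": "James", "department": "Sales", "state": "NY"},
    {"employee_name": "Ann", "department": "Finance", "state": "NY"},
    {"employee_name": "Lee", "department": "Sales", "state": "CA"},
]

# 200 mirrors the default spark.sql.shuffle.partitions
parts = hash_partition(rows, ("state",), 200)
non_empty = [i for i, p in enumerate(parts) if p]
print(len(non_empty))  # at most 2: one partition per distinct state
```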

3.3 Repartition by Multiple Columns

Let’s repartition the PySpark DataFrame by multiple columns. The following example redistributes the data by the state and department columns.

# repartition by multiple columns
df2 = df.repartition("state", "department")

# Write (example output path)
df2.write.mode("overwrite").csv("/tmp/partition.csv")
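With several columns, the hash is computed over the combined key, so the guarantee changes: rows sharing the same (state, department) pair stay together, but rows of one state may be split across partitions. The plain-Python sketch below illustrates this under the same assumptions as a simple hash-partitioning model: zlib.crc32 stands in for Spark's Murmur3 hash, and the (state, department) values and the partition count of 8 are hypothetical.

```python
import zlib
from collections import defaultdict
from typing import Dict, Set, Tuple


def partition_of(key: Tuple[str, ...], n: int) -> int:
    """Model of hash partitioning on several columns: hash the combined key
    and take it modulo n (zlib.crc32 stands in for Spark's Murmur3 hash)."""
    return zlib.crc32("|".join(key).encode("utf-8")) % n


# Hypothetical (state, department) values for five rows
rows = [("NY", "Sales"), ("NY", "Finance"), ("CA", "Sales"), ("CA", "Finance"), ("NY", "Sales")]
n = 8

by_pair: Dict[Tuple[str, str], Set[int]] = defaultdict(set)
by_state: Dict[str, Set[int]] = defaultdict(set)
for state, dept in rows:
    p = partition_of((state, dept), n)
    by_pair[(state, dept)].add(p)
    by_state[state].add(p)

# Every (state, department) combination lands in exactly one partition ...
print(all(len(ps) == 1 for ps in by_pair.values()))  # True
# ... but rows of one state may be spread across several partitions
print({s: sorted(ps) for s, ps in by_state.items()})
```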

4. PySpark RDD Repartition

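With RDDs, you can set the number of partitions when you create the RDD, by passing it to parallelize() (numSlices), textFile() or wholeTextFiles() (minPartitions). For textFile() and wholeTextFiles(), the value is a minimum, not an exact count.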

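# Without numSlices, the default parallelism (5 for local[5]) is used
rdd = spark.sparkContext.parallelize(range(0,20))
print("From local[5] : "+str(rdd.getNumPartitions()))

# Explicitly request 6 partitions
rdd1 = spark.sparkContext.parallelize(range(0,20), 6)
print("Parallelize : "+str(rdd1.getNumPartitions()))

# Request at least 10 partitions when reading a text file
rddFromFile = spark.sparkContext.textFile("/tmp/test.txt",10)
print("TextFile : "+str(rddFromFile.getNumPartitions()))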


From local[5] : 5
Parallelize : 6
TextFile : 10

spark.sparkContext.parallelize(range(0,20),6) distributes the RDD into 6 partitions, with the data distributed as below.

# Saving this RDD writes 6 part files, one for each partition
Partition 1 : 0 1 2
Partition 2 : 3 4 5
Partition 3 : 6 7 8 9
Partition 4 : 10 11 12
Partition 5 : 13 14 15
Partition 6 : 16 17 18 19
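The uneven boundaries in that listing (some partitions hold 3 elements, others 4) follow from how parallelize() slices a local collection: slice i starts at floor(i * size / numSlices). The plain-Python sketch below mirrors that slicing rule; the function names are hypothetical, and only the boundary formula is meant to reflect Spark's behavior.

```python
from typing import List


def slice_boundaries(size: int, num_slices: int) -> List[int]:
    """Start offset of each slice, plus the end: floor(i * size / num_slices)."""
    return [i * size // num_slices for i in range(num_slices + 1)]


def model_parallelize(data: List[int], num_slices: int) -> List[List[int]]:
    """Model of how parallelize(data, numSlices) cuts a collection into slices."""
    b = slice_boundaries(len(data), num_slices)
    return [data[b[i]:b[i + 1]] for i in range(num_slices)]


for i, part in enumerate(model_parallelize(list(range(0, 20)), 6), start=1):
    print(f"Partition {i} : {' '.join(map(str, part))}")
```

For 20 elements and 6 slices the boundaries are 0, 3, 6, 10, 13, 16 and 20, which reproduces the listing above line for line.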

4.1 RDD repartition()

The repartition() method of a PySpark RDD redistributes data across partitions, increasing or decreasing the number of partitions as specified. It triggers a full shuffle, moving data across the cluster, which can be costly.

The example below reduces rdd1 from 6 partitions to 4 by moving data from all partitions.

rdd2 = rdd1.repartition(4)
print("Repartition size : "+str(rdd2.getNumPartitions()))

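This prints Repartition size : 4. Because repartition() moves data from all partitions, it performs a full shuffle, which becomes very expensive when dealing with billions or trillions of records. The distribution below is one illustrative outcome; the exact placement of elements is not guaranteed and can differ, for example between Spark versions or between the Scala and Python APIs.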

Partition 1 : 1 6 10 15 19
Partition 2 : 2 3 7 11 16
Partition 3 : 4 8 12 13 17
Partition 4 : 0 5 9 14 18
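Why the placement is hard to predict can be seen in a plain-Python model of the shuffle step. As we understand Spark's implementation, a shuffling repartition gives each input partition a pseudo-random starting target, seeded by that partition's index, and then deals its elements round-robin from there. The sketch below models that rule, with Python's random.Random standing in for Spark's own generator, so its concrete output will not match Spark's. PySpark RDDs may additionally ship elements in small serialized batches, so small partitions can move as whole groups.

```python
import random
from typing import List


def model_shuffle_repartition(parts: List[List[int]], n: int) -> List[List[int]]:
    """Model of RDD.repartition(n): each input partition starts at a
    pseudo-random target (seeded by its index) and deals its elements
    round-robin from there. random.Random stands in for Spark's RNG."""
    out: List[List[int]] = [[] for _ in range(n)]
    for index, items in enumerate(parts):
        position = random.Random(index).randrange(n)
        for item in items:
            position += 1
            out[position % n].append(item)
    return out


# Start from the six partitions of parallelize(range(0, 20), 6) listed earlier
six = [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9], [10, 11, 12], [13, 14, 15], [16, 17, 18, 19]]
four = model_shuffle_repartition(six, 4)
for i, part in enumerate(four, start=1):
    print(f"Partition {i} : {' '.join(map(str, sorted(part)))}")
```

Every element appears exactly once in the result, but which partition it lands in depends on the seeds and the order of the input, which is why the listing above should be read as an example rather than a guarantee.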


You have learned the advantages and disadvantages of using the PySpark repartition() function, which redistributes RDD/DataFrame data into fewer or more partitions.
