Spark map() vs mapPartitions() with Examples

Spark map() and mapPartitions() transformations apply a function to each element/record/row of a DataFrame/Dataset and return a new DataFrame/Dataset. In this article, I will explain the difference between the map() and mapPartitions() transformations, their syntax, and their usage with Scala examples.

  • map() – Spark map() transformation applies a function to each row in a DataFrame/Dataset and returns the new transformed Dataset.
  • mapPartitions() – This is similar to map(); the difference is that Spark mapPartitions() provides a facility to do heavy initializations (for example, a database connection) once for each partition instead of on every DataFrame row. This improves job performance when you are dealing with heavyweight initialization on larger datasets.

Key Points:

  • One key point to remember is that both of these transformations return Dataset[U], not DataFrame (in Spark 2.0, DataFrame = Dataset[Row]).
  • After applying the transformation function to each row of the input DataFrame/Dataset, both return the same number of rows as the input, but the schema or the number of columns of the result can be different.
  • If you know the flatMap() transformation, this is the key difference between map() and flatMap(): map() returns exactly one row/element for every input, while flatMap() can return zero or more rows/elements, as shown in the sketch after this list.
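
To make that last point concrete, here is a minimal sketch on a plain Dataset[String]; the sample strings are made up for illustration, and it assumes a SparkSession named spark, as in the rest of this article.


import spark.implicits._

val words = Seq("hello world", "spark").toDS()

// map(): exactly one output element per input element -> 2 rows
val lengths = words.map(s => s.length)

// flatMap(): zero or more output elements per input element -> 3 rows
val tokens = words.flatMap(s => s.split(" "))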

Spark map() vs mapPartitions() Example

Let’s see the differences with an example. First, let’s create a Spark DataFrame.


  import org.apache.spark.sql.Row
  import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

  val structureData = Seq(
    Row("James","","Smith","36636","NewYork",3100),
    Row("Michael","Rose","","40288","California",4300),
    Row("Robert","","Williams","42114","Florida",1400),
    Row("Maria","Anne","Jones","39192","Florida",5500),
    Row("Jen","Mary","Brown","34561","NewYork",3000)
  )

  val structureSchema = new StructType()
    .add("firstname",StringType)
    .add("middlename",StringType)
    .add("lastname",StringType)
    .add("id",StringType)
    .add("location",StringType)
    .add("salary",IntegerType)

  val df2 = spark.createDataFrame(
    spark.sparkContext.parallelize(structureData),structureSchema)
  df2.printSchema()
  df2.show(false)

This yields the below output.


root
 |-- firstname: string (nullable = true)
 |-- middlename: string (nullable = true)
 |-- lastname: string (nullable = true)
 |-- id: string (nullable = true)
 |-- location: string (nullable = true)
 |-- salary: integer (nullable = true)

+---------+----------+--------+-----+----------+------+
|firstname|middlename|lastname|id   |location  |salary|
+---------+----------+--------+-----+----------+------+
|James    |          |Smith   |36636|NewYork   |3100  |
|Michael  |Rose      |        |40288|California|4300  |
|Robert   |          |Williams|42114|Florida   |1400  |
|Maria    |Anne      |Jones   |39192|Florida   |5500  |
|Jen      |Mary      |Brown   |34561|NewYork   |3000  |
+---------+----------+--------+-----+----------+------+

In order to explain map() and mapPartitions() with an example, let’s also create a Util class with a combine() method. This is a simple method that takes three string arguments and combines them with a comma delimiter. In a real-world scenario, this could be a third-party class that does a complex transformation.


class Util extends Serializable {
  def combine(fname:String,mname:String,lname:String):String = {
    fname+","+mname+","+lname
  }
}

We will create an object of this class and call its combine() method for each row of the DataFrame.
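
As a quick sanity check of what combine() produces, here is a minimal sketch run on the driver, outside of any Spark transformation:


val util = new Util()
println(util.combine("James", "", "Smith")) // prints James,,Smith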

Spark map() transformation

Spark map() transformation applies a function to each row in a DataFrame/Dataset and returns the new transformed Dataset. As mentioned earlier, map() returns one row for every row in the input DataFrame; in other words, the input and the result contain exactly the same number of rows.

For example, if you have 100 rows in a DataFrame, the map() result will contain exactly 100 rows. However, the structure or schema of the result can be different.
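
A minimal sketch of that row-count guarantee, assuming the df2 DataFrame created above:


import spark.implicits._

// map() preserves the row count: 5 rows in, 5 rows out
assert(df2.map(row => row.getString(0)).count() == df2.count())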

Syntax:


1) map[U](func : scala.Function1[T, U])(implicit evidence$6 : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]
2) map[U](func : org.apache.spark.api.java.function.MapFunction[T, U], encoder : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]

Spark provides two map() signatures: one takes a scala.Function1 as an argument and the other takes a MapFunction. Notice that both of them return Dataset[U], not DataFrame (which is Dataset[Row]). If you want a DataFrame as output, you need to convert the Dataset to a DataFrame using the toDF() function.

Usage:


import spark.implicits._
val df3 = df2.map(row=>{
  // This initialization happens for every record.
  // If it is a heavy initialization, like a database connection,
  // it degrades performance.
  val util = new Util()
  val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
  (fullName, row.getString(3),row.getInt(5))
})
val df3Map =  df3.toDF("fullName","id","salary")

df3Map.printSchema()
df3Map.show(false)

Since map() transformations execute on worker nodes, we have initialized and created the Util object inside the map() function, and the initialization happens for every row in the DataFrame. This causes performance issues when you have heavyweight initializations.

Note: When you run this in local mode, initializing the class outside of map() may still work, as both the executors and the driver run in the same JVM, but running it on a cluster typically fails with a serialization exception if the class is not serializable.
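
To illustrate that note, here is a minimal sketch with a hypothetical, non-serializable helper; DbClient is made up for illustration and is not part of the example above, and the snippet assumes the imports from the previous one.


// Hypothetical heavyweight helper that is NOT Serializable
class DbClient {
  def lookup(id: String): String = "result-for-" + id // imagine a real database call
}

val client = new DbClient() // created on the driver
// Capturing the non-serializable client in the closure typically fails on a
// cluster with "Task not serializable":
// val broken = df2.map(row => client.lookup(row.getString(3)))

// Constructing the helper inside map() avoids the problem, because it is
// created on the executor (at the cost of one instance per row):
val ok = df2.map(row => new DbClient().lookup(row.getString(3)))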

The above example yields the below output.


root
 |-- fullName: string (nullable = true)
 |-- id: string (nullable = true)
 |-- salary: integer (nullable = false)

+----------------+-----+------+
|fullName        |id   |salary|
+----------------+-----+------+
|James,,Smith    |36636|3100  |
|Michael,Rose,   |40288|4300  |
|Robert,,Williams|42114|1400  |
|Maria,Anne,Jones|39192|5500  |
|Jen,Mary,Brown  |34561|3000  |
+----------------+-----+------+

As you can see in the above output, the input DataFrame has 5 rows, so the result of map() also has 5 rows, but the number of columns is different.

Spark mapPartitions() transformation

Spark mapPartitions() provides a facility to do heavy initializations (for example, a database connection) once for each partition instead of on every DataFrame row. This improves job performance when you are dealing with heavyweight initialization on larger datasets.

Syntax:


1) mapPartitions[U](func : scala.Function1[scala.Iterator[T], scala.Iterator[U]])(implicit evidence$7 : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]
2) mapPartitions[U](f : org.apache.spark.api.java.function.MapPartitionsFunction[T, U], encoder : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]

mapPartitions() also has two signatures: one takes a scala.Function1 and the other takes a Spark MapPartitionsFunction argument.

Keep in mind that mapPartitions() holds a whole partition’s results in memory only if your function materializes the rows (for example, by collecting the iterator into a list); an iterator-to-iterator transformation, like the usage example below, stays lazy and processes rows one at a time. The sketch that follows contrasts the two.
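
A minimal sketch of that contrast, assuming the spark.implicits._ import from earlier:


// Lazy: rows are pulled through the iterator one at a time
df2.mapPartitions(it => it.map(row => row.getString(0).toUpperCase))

// Materialized: toList forces the whole partition into memory at once
df2.mapPartitions(it => it.toList.map(row => row.getString(0).toUpperCase).iterator)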

Usage:


  val df4 = df2.mapPartitions(iterator => {
    // Do the heavy initialization here,
    // like database connections, etc.
    val util = new Util()
    val res = iterator.map(row=>{
      val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
      (fullName, row.getString(3),row.getInt(5))
    })
    res
  })
  val df4part = df4.toDF("fullName","id","salary")
  df4part.printSchema()
  df4part.show(false)

This yields the same output as above.
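
In practice, a per-partition resource usually also needs to be released once the partition is done. Below is a minimal sketch of that pattern, reusing Util in place of a real connection; the close step is only indicated by a comment, since Util has nothing to close.


val df5 = df2.mapPartitions(iterator => {
  val util = new Util() // imagine opening a database connection here
  val rows = iterator.map(row =>
    (util.combine(row.getString(0), row.getString(1), row.getString(2)),
      row.getString(3), row.getInt(5))
  ).toList // materialize before releasing the resource, so no row touches it after close
  // close/release the connection here
  rows.iterator
})

The toList call trades memory for safety; for very large partitions, you could instead register a cleanup callback with TaskContext.addTaskCompletionListener so the returned iterator stays lazy.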

Complete example of Spark DataFrame map() & mapPartitions()

Below is a complete example of Spark DataFrame map() & mapPartitions().


package com.sparkbyexamples.spark.dataframe.examples

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Create the Util class
class Util extends Serializable {
  def combine(fname:String,mname:String,lname:String):String = {
    fname+","+mname+","+lname
  }
}

// Create the object to run
object MapTransformation extends App{

  val spark:SparkSession = SparkSession.builder()
    .master("local[5]")
    .appName("SparkByExamples.com")
    .getOrCreate()

  val structureData = Seq(
    Row("James","","Smith","36636","NewYork",3100),
    Row("Michael","Rose","","40288","California",4300),
    Row("Robert","","Williams","42114","Florida",1400),
    Row("Maria","Anne","Jones","39192","Florida",5500),
    Row("Jen","Mary","Brown","34561","NewYork",3000)
  )

  val structureSchema = new StructType()
    .add("firstname",StringType)
    .add("middlename",StringType)
    .add("lastname",StringType)
    .add("id",StringType)
    .add("location",StringType)
    .add("salary",IntegerType)

  val df2 = spark.createDataFrame(
    spark.sparkContext.parallelize(structureData),structureSchema)
  df2.printSchema()
  df2.show(false)

  import spark.implicits._
  // Util is Serializable, so creating it once on the driver also works here
  val util = new Util()
  val df3 = df2.map(row=>{

    val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
    (fullName, row.getString(3),row.getInt(5))
  })
  val df3Map =  df3.toDF("fullName","id","salary")

  df3Map.printSchema()
  df3Map.show(false)

  val df4 = df2.mapPartitions(iterator => {
    val util = new Util()
    val res = iterator.map(row=>{
      val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
      (fullName, row.getString(3),row.getInt(5))
    })
    res
  })
  val df4part = df4.toDF("fullName","id","salary")
  df4part.printSchema()
  df4part.show(false)
}

This example is also available at the Spark Examples GitHub project.

Conclusion

In this Spark DataFrame article, you have learned that the map() and mapPartitions() transformations execute a function on every row and return the same number of records as the input, but with the same or a different schema or columns. You have also learned that when you have a complex initialization, you should use mapPartitions(), as it can do the initialization once per partition instead of on every DataFrame row.

Thanks for reading. Leave me a comment if you like this article.

Happy Learning !!
