Spark map() vs mapPartitions() transformations

The Spark map() and mapPartitions() transformations apply a function to each element/record/row of a DataFrame/Dataset and return a new DataFrame/Dataset. In this article, I will explain the difference between the map() and mapPartitions() transformations, their syntax, and their usage with Scala examples.

Note: One key point to remember is that both of these transformations return a Dataset[U], not a DataFrame (since Spark 2.0, DataFrame = Dataset[Row]).

After applying the transformation function to each row of the input DataFrame/Dataset, they return the same number of rows as the input, but the schema or number of columns of the result could be different.

If you know the flatMap() transformation, this is the key difference between map() and flatMap(): map() returns exactly one row/element for every input, while flatMap() can return zero or more rows/elements.
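
For instance, here is a minimal sketch of that difference on a small Dataset of sentences (this snippet is illustrative only and assumes a SparkSession named spark, as in the rest of this article):

  import spark.implicits._

  val sentences = Seq("Spark is fast", "Scala is concise").toDS()

  // map(): exactly one output element per input element (the word count of each sentence)
  val counts = sentences.map(s => s.split(" ").length)    // Dataset[Int] with 2 rows

  // flatMap(): zero or more output elements per input element (every individual word)
  val words = sentences.flatMap(s => s.split(" "))        // Dataset[String] with 6 rows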

Before we start, let's create a Spark DataFrame.


  val structureData = Seq(
    Row("James","","Smith","36636","NewYork",3100),
    Row("Michael","Rose","","40288","California",4300),
    Row("Robert","","Williams","42114","Florida",1400),
    Row("Maria","Anne","Jones","39192","Florida",5500),
    Row("Jen","Mary","Brown","34561","NewYork",3000)
  )

  val structureSchema = new StructType()
    .add("firstname",StringType)
    .add("middlename",StringType)
    .add("lastname",StringType)
    .add("id",StringType)
    .add("location",StringType)
    .add("salary",IntegerType)

  val df2 = spark.createDataFrame(
    spark.sparkContext.parallelize(structureData),structureSchema)
  df2.printSchema()
  df2.show(false)

This yields the output below.


root
 |-- firstname: string (nullable = true)
 |-- middlename: string (nullable = true)
 |-- lastname: string (nullable = true)
 |-- id: string (nullable = true)
 |-- location: string (nullable = true)
 |-- salary: integer (nullable = true)

+---------+----------+--------+-----+----------+------+
|firstname|middlename|lastname|id   |location  |salary|
+---------+----------+--------+-----+----------+------+
|James    |          |Smith   |36636|NewYork   |3100  |
|Michael  |Rose      |        |40288|California|4300  |
|Robert   |          |Williams|42114|Florida   |1400  |
|Maria    |Anne      |Jones   |39192|Florida   |5500  |
|Jen      |Mary      |Brown   |34561|NewYork   |3000  |
+---------+----------+--------+-----+----------+------+

In order to explain map() and mapPartitions() with an example, let's also create a Util class with a combine() method. This is a simple function that takes three string arguments and combines them with a comma delimiter. In a real-world scenario, this could be a third-party class that does a complex transformation.


class Util extends Serializable {
  def combine(fname:String,mname:String,lname:String):String = {
    fname+","+mname+","+lname
  }
}

We will create an instance of this class and call its combine() method for each row in the DataFrame.
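
As a quick sanity check (a trivial local call, not part of the Spark job itself), combine() simply joins the three values:

  val util = new Util()
  println(util.combine("James", "", "Smith"))   // prints: James,,Smith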

Spark map() transformation

The Spark map() transformation applies a function to each row in a DataFrame/Dataset and returns the new transformed Dataset. As mentioned earlier, map() returns one row for every row in the input DataFrame; in other words, the input and the result contain exactly the same number of rows.

For example, if you have 100 rows in a DataFrame, applying map() returns exactly 100 rows. However, the structure or schema of the result could be different.
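
A quick way to confirm this, assuming the df2 created above (the uppercase transformation here is just an arbitrary illustration):

  import spark.implicits._
  // map() preserves the row count: df2 has 5 rows, so the result has 5 rows as well
  val upperNames = df2.map(row => row.getString(0).toUpperCase)
  assert(upperNames.count() == df2.count())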

Syntax:


1) map[U](func : scala.Function1[T, U])(implicit evidence$6 : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]
2) map[U](func : org.apache.spark.api.java.function.MapFunction[T, U], encoder : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]

Spark provides two map() signatures: one takes scala.Function1 as an argument and the other takes a Spark MapFunction. Notice that both of them return Dataset[U], not DataFrame (which is Dataset[Row]). If you want a DataFrame as output, you need to convert the Dataset to a DataFrame using the toDF() function.
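
The Usage section below uses the first (scala.Function1) signature. For completeness, here is a minimal sketch of the second, Java-friendly signature, which takes a MapFunction together with an explicit Encoder (the uppercase logic is just a placeholder, not part of this article's example):

  import org.apache.spark.api.java.function.MapFunction
  import org.apache.spark.sql.{Encoders, Row}

  val firstNames = df2.map(
    new MapFunction[Row, String] {
      override def call(row: Row): String = row.getString(0).toUpperCase
    },
    Encoders.STRING)                      // explicit encoder instead of spark.implicits._
  firstNames.toDF("firstname").show(false)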

Usage:


  import spark.implicits._
  val df3 = df2.map(row=>{
    val util = new Util()  // created for every row; see the note below
    val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
    (fullName, row.getString(3),row.getInt(5))  // tuple of (fullName, id, salary)
  })
  val df3Map =  df3.toDF("fullName","id","salary")  // convert the Dataset back to a DataFrame

  df3Map.printSchema()
  df3Map.show(false)

Since map() transformations execute on worker nodes, we have created the Util object inside the map() function, which means the initialization happens for every row in the DataFrame. This causes performance issues when you have heavyweight initializations.

Note: When you run this in local mode (as in this example, with master("local[5]")), initializing the class outside of map() still works because the executors and the driver run in the same JVM; running it on a real cluster, however, fails with a serialization exception if the class is not serializable.
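
To illustrate that note, here is a hypothetical sketch of what typically triggers the failure (NonSerializableUtil is an invented class used only for this illustration; our Util works either way because it extends Serializable):

  // This helper intentionally does NOT extend Serializable
  class NonSerializableUtil {
    def combine(fname:String, mname:String, lname:String): String = fname+","+mname+","+lname
  }

  val badUtil = new NonSerializableUtil()        // created on the driver
  val dfBad = df2.map(row =>
    badUtil.combine(row.getString(0), row.getString(1), row.getString(2)))
  // On a cluster this typically fails when the job runs, with
  // org.apache.spark.SparkException: Task not serializable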

The above example yields the output below.


root
 |-- fullName: string (nullable = true)
 |-- id: string (nullable = true)
 |-- salary: integer (nullable = false)

+----------------+-----+------+
|fullName        |id   |salary|
+----------------+-----+------+
|James,,Smith    |36636|3100  |
|Michael,Rose,   |40288|4300  |
|Robert,,Williams|42114|1400  |
|Maria,Anne,Jones|39192|5500  |
|Jen,Mary,Brown  |34561|3000  |
+----------------+-----+------+

As you can see in the output above, the input DataFrame has 5 rows, so the result of map() also has 5 rows, but the number of columns is different.

Spark mapPartitions() transformation

Spark mapPartitions() provides a facility to do heavy initializations (for example, a database connection) once for each partition instead of doing it for every row of the DataFrame. This helps the performance of the job when you are dealing with heavyweight initializations on larger datasets.
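
For example, here is a minimal sketch of that pattern with a hypothetical JDBC lookup (the connection URL, credentials, and enrichment logic are placeholders, not part of this article's example):

  val enriched = df2.mapPartitions(iterator => {
    // heavyweight initialization happens once per partition, not once per row
    val conn = java.sql.DriverManager.getConnection("jdbc:postgresql://host/db", "user", "pass")
    val rows = iterator.map(row => {
      // ... use conn here to look up or enrich each row ...
      (row.getString(0), row.getInt(5))
    }).toList                                  // materialize before closing the connection
    conn.close()
    rows.iterator
  })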

Syntax:


1) mapPartitions[U](func : scala.Function1[scala.Iterator[T], scala.Iterator[U]])(implicit evidence$7 : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]
2) mapPartitions[U](f : org.apache.spark.api.java.function.MapPartitionsFunction[T, U], encoder : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]

mapPartitions() also has two signatures: one takes scala.Function1 and the other takes a Spark MapPartitionsFunction as arguments.

mapPartitions() can keep the result of a partition in memory until it has processed all rows in that partition, so keep memory usage in mind for large partitions.
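
In other words, as long as you return the mapped iterator lazily (as in the Usage below), rows stream through one at a time; if you convert the iterator to a collection first, the whole partition is held in memory. A small sketch of both variants (the uppercase logic is just a placeholder):

  // Lazy: rows stream through the partition one at a time
  df2.mapPartitions(it => it.map(row => row.getString(0).toUpperCase))

  // Materialized: the entire partition is built in memory before the iterator is returned
  df2.mapPartitions(it => it.map(row => row.getString(0).toUpperCase).toList.iterator)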

Usage:


  val df4 = df2.mapPartitions(iterator => {
    val util = new Util()  // created once per partition, not once per row
    val res = iterator.map(row=>{
      val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
      (fullName, row.getString(3),row.getInt(5))
    })
    res  // return an iterator over the transformed rows
  })
  val df4part = df4.toDF("fullName","id","salary")
  df4part.printSchema()
  df4part.show(false)

This yields the same output as above.

Complete example of Spark DataFrame map() & mapPartitions()

Below is the complete example of the Spark DataFrame map() and mapPartitions() transformations.


package com.sparkbyexamples.spark.dataframe.examples

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Util class from the section above, repeated here so the example is self-contained
class Util extends Serializable {
  def combine(fname:String,mname:String,lname:String):String = {
    fname+","+mname+","+lname
  }
}

object MapTransformation extends App{

  val spark:SparkSession = SparkSession.builder()
    .master("local[5]")
    .appName("SparkByExamples.com")
    .getOrCreate()

  val structureData = Seq(
    Row("James","","Smith","36636","NewYork",3100),
    Row("Michael","Rose","","40288","California",4300),
    Row("Robert","","Williams","42114","Florida",1400),
    Row("Maria","Anne","Jones","39192","Florida",5500),
    Row("Jen","Mary","Brown","34561","NewYork",3000)
  )

  val structureSchema = new StructType()
    .add("firstname",StringType)
    .add("middlename",StringType)
    .add("lastname",StringType)
    .add("id",StringType)
    .add("location",StringType)
    .add("salary",IntegerType)

  val df2 = spark.createDataFrame(
    spark.sparkContext.parallelize(structureData),structureSchema)
  df2.printSchema()
  df2.show(false)

  import spark.implicits._
  val df3 = df2.map(row=>{
    val util = new Util()  // initialized for every row, as described in the map() section
    val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
    (fullName, row.getString(3),row.getInt(5))
  })
  val df3Map =  df3.toDF("fullName","id","salary")

  df3Map.printSchema()
  df3Map.show(false)

  val df4 = df2.mapPartitions(iterator => {
    val util = new Util()
    val res = iterator.map(row=>{
      val fullName = util.combine(row.getString(0),row.getString(1),row.getString(2))
      (fullName, row.getString(3),row.getInt(5))
    })
    res
  })
  val df4part = df4.toDF("fullName","id","salary")
  df4part.printSchema()
  df4part.show(false)
}

This example is also available at the Spark Examples GitHub project.

Conclusion

In this Spark DataFrame article, you have learned that the map() and mapPartitions() transformations execute a function on every row and return the same number of records as the input, but with the same or a different schema or columns. You have also learned that when you have a heavyweight initialization, you should use mapPartitions(), as it can perform the initialization once for each partition instead of once for every DataFrame row.

Thanks for reading. Leave me a comment if you like this article.

Happy Learning !!
