Spark map() vs flatMap() with Examples

The difference between Spark map() and flatMap() is one of the most frequently asked interview questions for Spark (Java/Scala/PySpark), so let’s understand the differences with examples. Interview or not, you should know the differences, as these are also among the most commonly used Spark transformations.

  • map() – Spark map() transformation applies a function to each row in a DataFrame/Dataset and returns the new transformed Dataset.
  • flatMap() – Spark flatMap() transformation flattens the DataFrame/Dataset after applying the function on every element and returns a new transformed Dataset. The returned Dataset can contain more rows than the current DataFrame, which is why it is also referred to as a one-to-many transformation. This is one of the major differences between flatMap() and map().

Key points

  • Both map() & flatMap() return a Dataset (DataFrame=Dataset[Row]).
  • Both transformations are narrow, meaning they do not result in a Spark data shuffle.
  • flatMap() can result in redundant (repeated) data in some columns.
  • One of the use cases of flatMap() is to flatten a column that contains an array, list, or any other nested collection, so that each resulting row holds a single value.
  • map() always returns the same number of records as the input DataFrame, whereas flatMap() can return many records for each input record (one-to-many); see the short RDD-level sketch after this list.
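
Since both transformations also exist on RDDs, here is a minimal RDD-level sketch of the one-to-one vs one-to-many behavior (this assumes a running SparkSession named spark; the data is a small made-up sample):

//RDD-level sketch: map() keeps one record per input, flatMap() emits one record per word
val rdd = spark.sparkContext.parallelize(Seq("Project Gutenberg’s", "Adventures in Wonderland"))

val mapped  = rdd.map(line => line.split(" "))      //2 records, each an Array[String]
val flatted = rdd.flatMap(line => line.split(" "))  //5 records, one word per record

println(mapped.count())  //2 -> same number of records as the input
println(flatted.count()) //5 -> more records than the input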

Spark map vs flatMap with Examples

Let’s see the difference with an example. First, let’s create the DataFrame that I will use for the map() and flatMap() transformations.


val data = Seq("Project Gutenberg’s",
    "Alice’s Adventures in Wonderland",
    "Project Gutenberg’s",
    "Adventures in Wonderland",
    "Project Gutenberg’s")

import spark.implicits._
val df = data.toDF("data")
df.show()

//Output
+--------------------+
|                data|
+--------------------+
| Project Gutenberg’s|
|Alice’s Adventure...|
| Project Gutenberg’s|
|Adventures in Won...|
| Project Gutenberg’s|
+--------------------+

Spark map() Transformation

Spark map() transformation applies a function to each row in a DataFrame/Dataset and returns the new transformed Dataset.

If you notice the signatures below, both of these overloads return a Dataset[U] and not a DataFrame (DataFrame=Dataset[Row]). If you want a DataFrame as output, you need to convert the Dataset to a DataFrame using the toDF() function.

Syntax:


1) map[U](func : scala.Function1[T, U])(implicit evidence$6 : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]
2) map[U](func : org.apache.spark.api.java.function.MapFunction[T, U], encoder : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]

Example: Here, I split the value of the column by space using the map() transformation. The split() function returns an Array, hence the column in the resulting Dataset is converted from String to Array type; you can verify this by calling mapDF.printSchema() (a short sketch follows the output below).


//Map Transformation
val mapDF = df.map(fun => {
    fun.getString(0).split(" ")
})
mapDF.show(false)

//Output
+-------------------------------------+
|value                                |
+-------------------------------------+
|[Project, Gutenberg’s]               |
|[Alice’s, Adventures, in, Wonderland]|
|[Project, Gutenberg’s]               |
|[Adventures, in, Wonderland]         |
|[Project, Gutenberg’s]               |
+-------------------------------------+
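
Since map() returned a Dataset rather than a DataFrame, you can inspect the new schema with printSchema() and, if needed, convert it back to a DataFrame with toDF(). A short sketch (the exact nullability flags printed may vary):

//Inspect the schema; the column produced by map() is now an array of strings
mapDF.printSchema()
//root
// |-- value: array (nullable = true)
// |    |-- element: string (containsNull = true)

//Convert the Dataset back to a DataFrame with a named column
val mapWordsDF = mapDF.toDF("words")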

Spark flatMap() Transformation

Spark flatMap() transformation flattens the DataFrame/Dataset after applying the function on every element and returns a new Dataset.

The returned Dataset can have the same number of elements or more elements than the current DataFrame. This is one of the major differences between flatMap() and map(), where the map() transformation always returns the same number of elements as the input.

Syntax:


1) flatMap[U](func : scala.Function1[T, scala.TraversableOnce[U]])(implicit evidence$8 : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]
2) flatMap[U](f : org.apache.spark.api.java.function.FlatMapFunction[T, U], encoder : org.apache.spark.sql.Encoder[U]) 
        : org.apache.spark.sql.Dataset[U]

Example 1: As in the map() example, I use the split() function with flatMap() as well, and it returns an Array. However, flatMap() flattens the array into individual rows, one per element; hence you end up with more records than in the input DataFrame.


//Flat Map Transformation
val flatMapDF=df.flatMap(fun=>
  {
     fun.getString(0).split(" ")
  })
flatMapDF.show()

//Output
+-----------+
|      value|
+-----------+
|    Project|
|Gutenberg’s|
|    Alice’s|
| Adventures|
|         in|
| Wonderland|
|    Project|
|Gutenberg’s|
| Adventures|
|         in|
| Wonderland|
|    Project|
|Gutenberg’s|
+-----------+
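
To confirm the one-to-many behavior, you can compare the record counts before and after the transformation:

//flatMap() returns more records than the input DataFrame
println(s"Input rows   : ${df.count()}")        //5
println(s"flatMap rows : ${flatMapDF.count()}") //13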

Example 2: Flatten an array column (one row per array element).


import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}
import spark.implicits._

val arrayStructureData = Seq(
    Row("James,,Smith",List("Java","Scala","C++"),"CA"),
    Row("Michael,Rose,",List("Spark","Java","C++"),"NJ"),
    Row("Robert,,Williams",List("CSharp","VB","R"),"NV")
)

val arrayStructureSchema = new StructType()
    .add("name",StringType)
    .add("languagesAtSchool", ArrayType(StringType))
    .add("currentState", StringType)

val df = spark.createDataFrame(
    spark.sparkContext.parallelize(arrayStructureData), arrayStructureSchema)

//flatMap() Usage
val df2=df.flatMap(f => {
    val lang=f.getSeq[String](1)
    lang.map((f.getString(0),_,f.getString(2)))
})

val df3=df2.toDF("Name","Language","State")
df3.show(false)

//Output
+----------------+--------+-----+
|Name            |Language|State|
+----------------+--------+-----+
|James,,Smith    |Java    |CA   |
|James,,Smith    |Scala   |CA   |
|James,,Smith    |C++     |CA   |
|Michael,Rose,   |Spark   |NJ   |
|Michael,Rose,   |Java    |NJ   |
|Michael,Rose,   |C++     |NJ   |
|Robert,,Williams|CSharp  |NV   |
|Robert,,Williams|VB      |NV   |
|Robert,,Williams|R       |NV   |
+----------------+--------+-----+
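
As a side note, the same flattening of an array column can also be achieved without a typed flatMap() by using the built-in explode() function; a minimal sketch using the column names from the example above:

//Equivalent flattening with the built-in explode() function
import org.apache.spark.sql.functions.{col, explode}

val explodedDF = df.select(
    col("name").as("Name"),
    explode(col("languagesAtSchool")).as("Language"),
    col("currentState").as("State")
)
explodedDF.show(false)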

Conclusion of map() vs flatMap()

In this article, you have learned that map() and flatMap() are transformations that exist on both RDD and DataFrame/Dataset. The map() transformation is used to transform data into different values or types and returns the same number of records, whereas the flatMap() transformation is used to transform one record into multiple records.

Happy Learning !!

