Spark flatMap() transformation flattens the RDD/DataFrame after applying a function on every element, and returns a new RDD/DataFrame respectively.
The returned RDD/DataFrame can have a different number of elements than the input. This is one of the major differences between flatMap() and map(): the map() transformation always returns exactly one element for each input element, while flatMap() can return zero or more.
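To see the difference, here is a minimal sketch using plain Scala collections, which share the same map()/flatMap() semantics as Spark RDDs:

val lines = Seq("a b", "c d e")
// map() returns one element per input element: 2 arrays for 2 input strings
val mapped = lines.map(line => line.split(" "))
// flatMap() flattens those arrays into their elements: 5 strings
val flatMapped = lines.flatMap(line => line.split(" "))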
First, let’s create an RDD by loading the data from a Seq collection.
val data = Seq("Project Gutenberg’s",
"Alice’s Adventures in Wonderland",
"Project Gutenberg’s",
"Adventures in Wonderland",
"Project Gutenberg’s")
val rdd = spark.sparkContext.parallelize(data)
rdd.foreach(println)
This yields the below output.

Project Gutenberg’s
Alice’s Adventures in Wonderland
Project Gutenberg’s
Adventures in Wonderland
Project Gutenberg’s
flatMap() Syntax
flatMap[U](f : scala.Function1[T, scala.TraversableOnce[U]])(implicit evidence$4 : scala.reflect.ClassTag[U]) : org.apache.spark.rdd.RDD[U]
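As the signature shows, the function you pass to flatMap() can return any TraversableOnce[U], for example an Array (as returned by split()), a List, or an Option. Below is a minimal sketch reusing the rdd created above; the variable names are illustrative only.

// The function may return any TraversableOnce -- all of these are valid:
val fromArray  = rdd.flatMap(line => line.split(" "))                 // Array[String]
val fromList   = rdd.flatMap(line => List(line, line.toUpperCase))    // List[String]
val fromOption = rdd.flatMap(line => Option(line).filter(_.nonEmpty)) // Option[String], converted to Iterable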
flatMap() Example
Now, let’s see how to apply a Spark flatMap() transformation on an RDD with an example. The below example splits each element in the RDD by space and flattens the results.
val rdd1 = rdd.flatMap(f => f.split(" "))
rdd1.foreach(println)
This yields the below output. The resulting RDD consists of a single word on each record.
Project
Gutenberg’s
Alice’s
Adventures
in
Wonderland
Project
Gutenberg’s
Adventures
in
Wonderland
Project
Gutenberg’s
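Note that flatMap() can also return fewer elements than its input, since the function may return an empty collection for some elements. As a minimal sketch (reusing rdd1 from above), the following keeps only words longer than two characters, dropping “in”:

// Returning an empty Seq removes the element from the result
val rdd2 = rdd1.flatMap(word => if (word.length > 2) Seq(word) else Seq.empty[String])
rdd2.foreach(println)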
Complete Spark RDD flatMap() example
Below is a complete example of the Spark RDD flatMap() transformation.
import org.apache.spark.sql.SparkSession

object FlatMapExample extends App {

  val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("SparkByExamples.com")
    .getOrCreate()

  val data = Seq("Project Gutenberg’s",
    "Alice’s Adventures in Wonderland",
    "Project Gutenberg’s",
    "Adventures in Wonderland",
    "Project Gutenberg’s")

  val rdd = spark.sparkContext.parallelize(data)
  rdd.foreach(println)

  // Split each line by space and flatten the resulting arrays into words
  val rdd1 = rdd.flatMap(f => f.split(" "))
  rdd1.foreach(println)
}
Using flatMap() on Spark DataFrame
flatMap() on a Spark DataFrame operates similarly to RDDs: when applied, it executes the specified function on every element of the DataFrame, splitting or merging elements as needed, so the record count of the flatMap() result can differ from the input.
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

val arrayStructureData = Seq(
  Row("James,,Smith", List("Java","Scala","C++"), "CA"),
  Row("Michael,Rose,", List("Spark","Java","C++"), "NJ"),
  Row("Robert,,Williams", List("CSharp","VB","R"), "NV")
)

val arrayStructureSchema = new StructType()
  .add("name", StringType)
  .add("languagesAtSchool", ArrayType(StringType))
  .add("currentState", StringType)

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(arrayStructureData), arrayStructureSchema)
import spark.implicits._
// flatMap() usage: pair each language with the name and state of its row
val df2 = df.flatMap(f => f.getSeq[String](1).map((f.getString(0), _, f.getString(2))))
  .toDF("Name", "Language", "State")
df2.show(false)
This yields the below output after the flatMap() transformation. As you can notice, the input DataFrame has 3 records, but after exploding the “Language” column using flatMap(), the result has 9 records.
+----------------+--------+-----+
|Name |Language|State|
+----------------+--------+-----+
|James,,Smith |Java |CA |
|James,,Smith |Scala |CA |
|James,,Smith |C++ |CA |
|Michael,Rose, |Spark |NJ |
|Michael,Rose, |Java |NJ |
|Michael,Rose, |C++ |NJ |
|Robert,,Williams|CSharp |NV |
|Robert,,Williams|VB |NV |
|Robert,,Williams|R |NV |
+----------------+--------+-----+
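As a quick check, you can compare the record counts before and after the transformation using count():

println(s"Before flatMap: ${df.count()} records")  // 3
println(s"After flatMap : ${df2.count()} records") // 9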
Conclusion
In conclusion, you have learned the syntax and usage of the flatMap() transformation and have seen how it applies a function on every element of a Spark RDD and DataFrame.
Happy Learning !!