Spark SQL – Flatten Nested Struct column

In Spark SQL, flattening the nested struct columns of a DataFrame is simple for one level of the hierarchy, but complex when you have multiple levels and hundreds of columns. With one level of nesting you can simply flatten by referring to the struct fields with dot notation, but with a multi-level struct column things get complex, and you need to write logic that iterates over all columns and builds up the column list to select.

In this article, I will explain how to flatten the nested (single or multi-level) struct column using a Scala example.

First, let’s create a DataFrame with nested struct columns.


val structureData = Seq(
    Row(Row("James ","","Smith"),Row(Row("CA","Los Angles"),Row("CA","Sandiago"))),
    Row(Row("Michael ","Rose",""),Row(Row("NY","New York"),Row("NJ","Newark"))),
    Row(Row("Robert ","","Williams"),Row(Row("DE","Newark"),Row("CA","Las Vegas"))),
    Row(Row("Maria ","Anne","Jones"),Row(Row("PA","Harrisburg"),Row("CA","Sandiago"))),
    Row(Row("Jen","Mary","Brown"),Row(Row("CA","Los Angles"),Row("NJ","Newark")))
  )

val structureSchema = new StructType()
    .add("name",new StructType()
      .add("firstname",StringType)
      .add("middlename",StringType)
      .add("lastname",StringType))
    .add("address",new StructType()
      .add("current",new StructType()
        .add("state",StringType)
        .add("city",StringType))
      .add("previous",new StructType()
        .add("state",StringType)
        .add("city",StringType)))

val df = spark.createDataFrame(
    spark.sparkContext.parallelize(structureData),structureSchema)
df.printSchema()

df.printSchema() yields the schema below.


root
 |-- name: struct (nullable = true)
 |    |-- firstname: string (nullable = true)
 |    |-- middlename: string (nullable = true)
 |    |-- lastname: string (nullable = true)
 |-- address: struct (nullable = true)
 |    |-- current: struct (nullable = true)
 |    |    |-- state: string (nullable = true)
 |    |    |-- city: string (nullable = true)
 |    |-- previous: struct (nullable = true)
 |    |    |-- state: string (nullable = true)
 |    |    |-- city: string (nullable = true)

In this example, column “firstname” is a first-level nested field, while columns “state” and “city” are multi-level nested fields (meaning two or more levels deep in the hierarchy).
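Before flattening everything, it helps to see dot notation on a single field. Below is a minimal, self-contained sketch (the session setup and tiny dataset are hypothetical, not from the article) that builds a one-level and a two-level struct and selects individual nested fields:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, struct}

// Hypothetical minimal session; in the article, `spark` and `df` already exist.
val spark = SparkSession.builder().master("local[1]")
  .appName("dot-notation-sketch").getOrCreate()
import spark.implicits._

val sketchDf = Seq(("James", "Smith", "CA", "Los Angles"))
  .toDF("firstname", "lastname", "state", "city")
  .select(
    struct("firstname", "lastname").as("name"),                  // one level deep
    struct(struct("state", "city").as("current")).as("address")) // two levels deep

// "name.firstname" is one level deep; "address.current.city" is two levels deep
sketchDf.select(col("name.firstname"), col("address.current.city")).show(false)
```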


df.show(false)

Yields the output below.


+---------------------+----------------------------------+
|name                 |address                           |
+---------------------+----------------------------------+
|[James , , Smith]    |[[CA, Los Angles], [CA, Sandiago]]|
|[Michael , Rose, ]   |[[NY, New York], [NJ, Newark]]    |
|[Robert , , Williams]|[[DE, Newark], [CA, Las Vegas]]   |
|[Maria , Anne, Jones]|[[PA, Harrisburg], [CA, Sandiago]]|
|[Jen, Mary, Brown]   |[[CA, Los Angles], [NJ, Newark]]  |
+---------------------+----------------------------------+

Now, let’s flatten it the simple way. Here, we refer to the nested struct columns using dot notation (parentColumn.childColumn).


val df2 = df.select(col("name.*"),
    col("address.current.*"),
    col("address.previous.*"))
val df2Flatten = df2.toDF("fname","mename","lname","currAddState",
    "currAddCity","prevAddState","prevAddCity")
df2Flatten.printSchema()
df2Flatten.show(false)

The above snippet flattens all columns in the DataFrame and renames them with toDF(). Using this approach, you can also choose which columns you want to flatten.


root
 |-- fname: string (nullable = true)
 |-- mename: string (nullable = true)
 |-- lname: string (nullable = true)
 |-- currAddState: string (nullable = true)
 |-- currAddCity: string (nullable = true)
 |-- prevAddState: string (nullable = true)
 |-- prevAddCity: string (nullable = true)

+--------+------+--------+------------+-----------+------------+-----------+
|fname   |mename|lname   |currAddState|currAddCity|prevAddState|prevAddCity|
+--------+------+--------+------------+-----------+------------+-----------+
|James   |      |Smith   |CA          |Los Angles |CA          |Sandiago   |
|Michael |Rose  |        |NY          |New York   |NJ          |Newark     |
|Robert  |      |Williams|DE          |Newark     |CA          |Las Vegas  |
|Maria   |Anne  |Jones   |PA          |Harrisburg |CA          |Sandiago   |
|Jen     |Mary  |Brown   |CA          |Los Angles |NJ          |Newark     |
+--------+------+--------+------------+-----------+------------+-----------+
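As noted above, you can flatten only part of the DataFrame with this approach. Here is a hedged, self-contained sketch (session setup and sample data are hypothetical) that flattens the "name" struct while leaving "address" nested:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, struct}

// Hypothetical minimal session; in the article, `spark` and `df` already exist.
val spark = SparkSession.builder().master("local[1]")
  .appName("partial-flatten-sketch").getOrCreate()
import spark.implicits._

val partialDf = Seq(("Jen", "Brown", "CA", "Los Angles"))
  .toDF("firstname", "lastname", "state", "city")
  .select(struct("firstname", "lastname").as("name"),
          struct("state", "city").as("address"))

// Flatten name.* but keep address as a struct column
val partiallyFlat = partialDf.select(col("name.*"), col("address"))
partiallyFlat.printSchema()
```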

Since we have only a few columns here, referring to them by name seems easy, but imagine how cumbersome it would be to list 100+ columns in a select.

Now let’s see a different approach that can easily flatten hundreds of columns across many nested levels. We will do this by creating a recursive function, flattenStructSchema(), which iterates over the schema at every level and builds up an Array[Column].


def flattenStructSchema(schema: StructType, prefix: String = null) : Array[Column] = {
    schema.fields.flatMap(f => {
      // Build the fully qualified column name, e.g. "address.current.state"
      val columnName = if (prefix == null) f.name else (prefix + "." + f.name)

      f.dataType match {
        // Recurse into nested structs, carrying the qualified name as the prefix
        case st: StructType => flattenStructSchema(st, columnName)
        // Leaf field: select it and alias dots to underscores, e.g. "address_current_state"
        case _ => Array(col(columnName).as(columnName.replace(".","_")))
      }
    })
  }
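To see what the recursion does independent of any DataFrame, here is a schema-only variant (the helper fieldPaths is hypothetical, for illustration only): it returns the dotted field paths the function above would select, without needing a SparkSession.

```scala
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Hypothetical helper: collect the dotted paths of all leaf fields in a schema.
def fieldPaths(schema: StructType, prefix: String = ""): Seq[String] =
  schema.fields.toSeq.flatMap { f =>
    val name = if (prefix.isEmpty) f.name else s"$prefix.${f.name}"
    f.dataType match {
      case st: StructType => fieldPaths(st, name) // recurse into nested structs
      case _              => Seq(name)            // leaf field: emit its path
    }
  }

val demoSchema = new StructType()
  .add("name", new StructType()
    .add("firstname", StringType)
    .add("lastname", StringType))
  .add("age", IntegerType)

println(fieldPaths(demoSchema))
```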

To keep things simple, I still use the same DataFrame from the previous section.


val df3 = df.select(flattenStructSchema(df.schema):_*)
df3.printSchema()
df3.show(false)

Yields the output below.


+--------------+---------------+-------------+---------------------+--------------------+----------------------+---------------------+
|name_firstname|name_middlename|name_lastname|address_current_state|address_current_city|address_previous_state|address_previous_city|
+--------------+---------------+-------------+---------------------+--------------------+----------------------+---------------------+
|James         |               |Smith        |CA                   |Los Angles          |CA                    |Sandiago             |
|Michael       |Rose           |             |NY                   |New York            |NJ                    |Newark               |
|Robert        |               |Williams     |DE                   |Newark              |CA                    |Las Vegas            |
|Maria         |Anne           |Jones        |PA                   |Harrisburg          |CA                    |Sandiago             |
|Jen           |Mary           |Brown        |CA                   |Los Angles          |NJ                    |Newark               |
+--------------+---------------+-------------+---------------------+--------------------+----------------------+---------------------+

Complete example of Flattening nested struct

This complete example is also available at the Spark Examples GitHub project for reference.


import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType,StructType}
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col

object FlattenNestedStruct extends App {

  val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("SparkByExamples.com")
    .getOrCreate()

  val structureData = Seq(
    Row(Row("James ","","Smith"),Row(Row("CA","Los Angles"),Row("CA","Sandiago"))),
    Row(Row("Michael ","Rose",""),Row(Row("NY","New York"),Row("NJ","Newark"))),
    Row(Row("Robert ","","Williams"),Row(Row("DE","Newark"),Row("CA","Las Vegas"))),
    Row(Row("Maria ","Anne","Jones"),Row(Row("PA","Harrisburg"),Row("CA","Sandiago"))),
    Row(Row("Jen","Mary","Brown"),Row(Row("CA","Los Angles"),Row("NJ","Newark")))
  )

  val structureSchema = new StructType()
    .add("name",new StructType()
      .add("firstname",StringType)
      .add("middlename",StringType)
      .add("lastname",StringType))
    .add("address",new StructType()
      .add("current",new StructType()
        .add("state",StringType)
        .add("city",StringType))
      .add("previous",new StructType()
        .add("state",StringType)
        .add("city",StringType)))


  val df = spark.createDataFrame(
    spark.sparkContext.parallelize(structureData),structureSchema)
  df.printSchema()
  df.show(false)

  val df2 = df.select(col("name.*"),
    col("address.current.*"),
    col("address.previous.*"))
  df2.toDF("fname","mename","lname","currAddState",
    "currAddCity","prevAddState","prevAddCity")
    .show(false)



  def flattenStructSchema(schema: StructType, prefix: String = null) : Array[Column] = {
    schema.fields.flatMap(f => {
      val columnName = if (prefix == null) f.name else (prefix + "." + f.name)

      f.dataType match {
        case st: StructType => flattenStructSchema(st, columnName)
        case _ => Array(col(columnName).as(columnName.replace(".","_")))
      }
    })
  }

  val df3 = df.select(flattenStructSchema(df.schema):_*)
  df3.printSchema()
  df3.show(false)
}
Conclusion

In this Spark article, you have learned how to flatten a nested struct column for both single-level and multi-level struct types. I hope you found it useful.

Happy Learning !!

NNK
