Spark SQL – Flatten Nested Struct Column

In Spark SQL, flattening a nested struct column (converting a struct to columns) of a DataFrame is simple for one level of the hierarchy and complex when you have multiple levels and hundreds of columns. With one level of structure, you can simply flatten it by referring to the struct fields with dot notation. With a multi-level struct column, things get complex, and you need to write logic that iterates over all the columns and builds the list of columns to select.

In this article, I will explain how to convert/flatten the nested (single or multi-level) struct column using a Scala example.


First, let’s create a DataFrame with nested struct columns.


import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StringType, StructType}

// Assumes an existing SparkSession named `spark` (as in spark-shell)
val structureData = Seq(
  Row(Row("James ","","Smith"),Row(Row("CA","Los Angles"),Row("CA","Sandiago"))),
  Row(Row("Michael ","Rose",""),Row(Row("NY","New York"),Row("NJ","Newark"))),
  Row(Row("Robert ","","Williams"),Row(Row("DE","Newark"),Row("CA","Las Vegas"))),
  Row(Row("Maria ","Anne","Jones"),Row(Row("PA","Harrisburg"),Row("CA","Sandiago"))),
  Row(Row("Jen","Mary","Brown"),Row(Row("CA","Los Angles"),Row("NJ","Newark")))
)

// name holds one level of nesting; address.current and address.previous hold two
val structureSchema = new StructType()
  .add("name",new StructType()
    .add("firstname",StringType)
    .add("middlename",StringType)
    .add("lastname",StringType))
  .add("address",new StructType()
    .add("current",new StructType()
      .add("state",StringType)
      .add("city",StringType))
    .add("previous",new StructType()
      .add("state",StringType)
      .add("city",StringType)))

val df = spark.createDataFrame(
  spark.sparkContext.parallelize(structureData),structureSchema)
df.printSchema()

df.printSchema() yields the schema below.


root
 |-- name: struct (nullable = true)
 |    |-- firstname: string (nullable = true)
 |    |-- middlename: string (nullable = true)
 |    |-- lastname: string (nullable = true)
 |-- address: struct (nullable = true)
 |    |-- current: struct (nullable = true)
 |    |    |-- state: string (nullable = true)
 |    |    |-- city: string (nullable = true)
 |    |-- previous: struct (nullable = true)
 |    |    |-- state: string (nullable = true)
 |    |    |-- city: string (nullable = true)

In this example, the column “firstname” is nested one level deep (inside “name”), while the columns “state” and “city” are nested two levels deep (inside “address.current” and “address.previous”), which is what makes them multi-level.
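Each level of the hierarchy maps directly to one segment of a dot-notation path, so a single path such as address.current.city reaches a field at any depth. The sketch below rebuilds a one-row version of the DataFrame so it runs on its own (the app name and the aliases firstname, currentCity and prevState are illustrative) and selects a one-level and a two-level field with dot notation. It also shows Column.getField, which builds the same path one level at a time.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder().master("local[1]").appName("NestedFieldAccess").getOrCreate()

// One-row copy of the article's schema: fields under name are 1 level deep,
// fields under address.current and address.previous are 2 levels deep
val schema = new StructType()
  .add("name", new StructType()
    .add("firstname", StringType).add("middlename", StringType).add("lastname", StringType))
  .add("address", new StructType()
    .add("current", new StructType().add("state", StringType).add("city", StringType))
    .add("previous", new StructType().add("state", StringType).add("city", StringType)))
val rows = Seq(Row(Row("Jen", "Mary", "Brown"), Row(Row("CA", "Los Angles"), Row("NJ", "Newark"))))
val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

val picked = df.select(
  col("name.firstname").as("firstname"),                                 // one level deep
  col("address.current.city").as("currentCity"),                         // two levels deep
  col("address").getField("previous").getField("state").as("prevState") // same idea via getField
)
picked.show(false)
```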


df.show(false)

This yields the output below.


+---------------------+----------------------------------+
|name                 |address                           |
+---------------------+----------------------------------+
|[James , , Smith]    |[[CA, Los Angles], [CA, Sandiago]]|
|[Michael , Rose, ]   |[[NY, New York], [NJ, Newark]]    |
|[Robert , , Williams]|[[DE, Newark], [CA, Las Vegas]]   |
|[Maria , Anne, Jones]|[[PA, Harrisburg], [CA, Sandiago]]|
|[Jen, Mary, Brown]   |[[CA, Los Angles], [NJ, Newark]]  |
+---------------------+----------------------------------+

Now, let’s flatten it the simple way. Here, we refer to nested struct columns using dot notation (parentColumn.childColumn), and selecting parentColumn.* expands every field of that struct into its own column.


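import org.apache.spark.sql.functions.col

// "name.*" expands the name struct; the dot paths reach into address's nested structs
val df2 = df.select(col("name.*"),
    col("address.current.*"),
    col("address.previous.*"))
val df2Flatten = df2.toDF("fname","mename","lname","currAddState",
    "currAddCity","prevAddState","prevAddCity")
df2Flatten.printSchema()
df2Flatten.show(false)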

The snippet above flattens all struct columns in the DataFrame, and toDF() gives the resulting columns shorter names. With this approach, you can also choose which columns to flatten.


root
 |-- fname: string (nullable = true)
 |-- mename: string (nullable = true)
 |-- lname: string (nullable = true)
 |-- currAddState: string (nullable = true)
 |-- currAddCity: string (nullable = true)
 |-- prevAddState: string (nullable = true)
 |-- prevAddCity: string (nullable = true)

+--------+------+--------+------------+-----------+------------+-----------+
|fname   |mename|lname   |currAddState|currAddCity|prevAddState|prevAddCity|
+--------+------+--------+------------+-----------+------------+-----------+
|James   |      |Smith   |CA          |Los Angles |CA          |Sandiago   |
|Michael |Rose  |        |NY          |New York   |NJ          |Newark     |
|Robert  |      |Williams|DE          |Newark     |CA          |Las Vegas  |
|Maria   |Anne  |Jones   |PA          |Harrisburg |CA          |Sandiago   |
|Jen     |Mary  |Brown   |CA          |Los Angles |NJ          |Newark     |
+--------+------+--------+------------+-----------+------------+-----------+
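toDF() assigns new names by position. If a struct’s fields are ever reordered, the new names silently land on the wrong data, and if a field is added, toDF() fails because the column count no longer matches. A sketch of an alternative, assuming the same schema, aliases each nested field explicitly in a single select so that every new name stays tied to its source path. The renames list and the one-row DataFrame are illustrative.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder().master("local[1]").appName("ExplicitAliases").getOrCreate()

// One-row copy of the article's nested schema
val schema = new StructType()
  .add("name", new StructType()
    .add("firstname", StringType).add("middlename", StringType).add("lastname", StringType))
  .add("address", new StructType()
    .add("current", new StructType().add("state", StringType).add("city", StringType))
    .add("previous", new StructType().add("state", StringType).add("city", StringType)))
val rows = Seq(Row(Row("Michael ", "Rose", ""), Row(Row("NY", "New York"), Row("NJ", "Newark"))))
val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Each alias is bound to an explicit path, so field order inside the structs does not matter
val renames = Seq(
  "name.firstname"         -> "fname",
  "name.middlename"        -> "mename",
  "name.lastname"          -> "lname",
  "address.current.state"  -> "currAddState",
  "address.current.city"   -> "currAddCity",
  "address.previous.state" -> "prevAddState",
  "address.previous.city"  -> "prevAddCity"
)
val flat = df.select(renames.map { case (path, alias) => col(path).as(alias) }: _*)
flat.printSchema()
flat.show(false)
```

The output columns match the toDF() version, but the mapping is spelled out instead of depending on position.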

With only a few columns, referring to each column by name is easy, but imagine how cumbersome it becomes when you have 100+ columns and have to list every one of them in a select.

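Now let’s look at a different approach that can flatten hundreds of nested columns without listing them. We will do this by writing a recursive function, flattenStructSchema(), which walks the schema at every level and returns an Array[Column] with one column per leaf field, aliased with underscores in place of dots (for example, name.firstname becomes name_firstname).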


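import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StructType

// Recursively walks the schema and returns one Column per leaf (non-struct) field
def flattenStructSchema(schema: StructType, prefix: String = null) : Array[Column] = {
  schema.fields.flatMap(f => {
    // Full dotted path to this field, e.g. "address.current.state"
    val columnName = if (prefix == null) f.name else (prefix + "." + f.name)

    f.dataType match {
      case st: StructType => flattenStructSchema(st, columnName)  // descend into the nested struct
      // Any other type (including arrays and maps) is kept as-is and aliased a.b.c -> a_b_c
      case _ => Array(col(columnName).as(columnName.replace(".","_")))
    }
  })
}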
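One caveat applies to field names. flattenStructSchema joins names with "." and passes the path to col(), so a field whose own name contains a dot (for example, a JSON key such as geo.lat) is read as a deeper path. In that case, loc.geo.lat is resolved as loc → geo → lat, and the select fails with an AnalysisException. A hedged variant, here called flattenStructSchemaQuoted (a hypothetical name), keeps the path as a list of segments and wraps each one in backticks before building the column reference. Any backtick inside a name is doubled, which is how Spark escapes it. The sample DataFrame with a geo.lat field is made up for illustration.

```scala
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{DoubleType, StringType, StructType}

val spark = SparkSession.builder().master("local[1]").appName("FlattenQuoted").getOrCreate()

// Hypothetical variant of flattenStructSchema that tracks the path as separate segments
def flattenStructSchemaQuoted(schema: StructType, path: Seq[String] = Seq.empty): Array[Column] =
  schema.fields.flatMap { f =>
    val fieldPath = path :+ f.name
    f.dataType match {
      case st: StructType => flattenStructSchemaQuoted(st, fieldPath) // recurse into nested struct
      case _ =>
        // `loc`.`geo.lat` tells Spark that "geo.lat" is one field name, not two levels
        val ref = fieldPath.map(n => "`" + n.replace("`", "``") + "`").mkString(".")
        // Dots inside names are also replaced so the output names need no quoting later
        Array(col(ref).as(fieldPath.mkString("_").replace(".", "_")))
    }
  }

// Sample data with a dotted field name inside a struct
val schema = new StructType()
  .add("id", StringType)
  .add("loc", new StructType().add("geo.lat", DoubleType).add("city", StringType))
val nested = spark.createDataFrame(
  spark.sparkContext.parallelize(Seq(Row("a1", Row(34.05, "Los Angles")))), schema)

val flat = nested.select(flattenStructSchemaQuoted(nested.schema): _*)
flat.printSchema()
flat.show(false)
```

For schemas without dotted or backticked names, this produces the same aliases as flattenStructSchema.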

To keep it simple, I use the same DataFrame from the previous section and pass the returned array to select() as varargs with :_*.


val df3 = df.select(flattenStructSchema(df.schema):_*)
df3.printSchema()
df3.show(false)

This yields the schema and output below.

root
 |-- name_firstname: string (nullable = true)
 |-- name_middlename: string (nullable = true)
 |-- name_lastname: string (nullable = true)
 |-- address_current_state: string (nullable = true)
 |-- address_current_city: string (nullable = true)
 |-- address_previous_state: string (nullable = true)
 |-- address_previous_city: string (nullable = true)

+--------------+---------------+-------------+---------------------+--------------------+----------------------+---------------------+
|name_firstname|name_middlename|name_lastname|address_current_state|address_current_city|address_previous_state|address_previous_city|
+--------------+---------------+-------------+---------------------+--------------------+----------------------+---------------------+
|James         |               |Smith        |CA                   |Los Angles          |CA                    |Sandiago             |
|Michael       |Rose           |             |NY                   |New York            |NJ                    |Newark               |
|Robert        |               |Williams     |DE                   |Newark              |CA                    |Las Vegas            |
|Maria         |Anne           |Jones        |PA                   |Harrisburg          |CA                    |Sandiago             |
|Jen           |Mary           |Brown        |CA                   |Los Angles          |NJ                    |Newark               |
+--------------+---------------+-------------+---------------------+--------------------+----------------------+---------------------+
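Because flattenStructSchema accepts a prefix, it can also flatten only selected struct columns. The sketch below assumes you want to expand address but keep name as a struct. It passes the address sub-schema with "address" as the prefix and prepends col("name") to the result. The function is repeated and a one-row DataFrame is rebuilt so the sketch runs on its own; partlyFlat is an illustrative name.

```scala
import org.apache.spark.sql.{Column, Row, SparkSession}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{StringType, StructType}

val spark = SparkSession.builder().master("local[1]").appName("PartialFlatten").getOrCreate()

// Same function as in the article
def flattenStructSchema(schema: StructType, prefix: String = null): Array[Column] = {
  schema.fields.flatMap(f => {
    val columnName = if (prefix == null) f.name else (prefix + "." + f.name)
    f.dataType match {
      case st: StructType => flattenStructSchema(st, columnName)
      case _ => Array(col(columnName).as(columnName.replace(".", "_")))
    }
  })
}

// One-row copy of the article's nested schema
val schema = new StructType()
  .add("name", new StructType()
    .add("firstname", StringType).add("middlename", StringType).add("lastname", StringType))
  .add("address", new StructType()
    .add("current", new StructType().add("state", StringType).add("city", StringType))
    .add("previous", new StructType().add("state", StringType).add("city", StringType)))
val rows = Seq(Row(Row("James ", "", "Smith"), Row(Row("CA", "Los Angles"), Row("CA", "Sandiago"))))
val df = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)

// Flatten only the address struct; the prefix keeps full paths such as address.current.state
val addressSchema = df.schema("address").dataType.asInstanceOf[StructType]
val partlyFlat = df.select((col("name") +: flattenStructSchema(addressSchema, "address")): _*)
partlyFlat.printSchema()
partlyFlat.show(false)
```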

Complete Example of Converting Struct to Columns

This complete code is also available in the Spark Example GitHub project for reference.


import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{StringType,StructType}
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.col

object FlattenNestedStruct extends App {

  val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("SparkByExamples.com")
    .getOrCreate()

  val structureData = Seq(
    Row(Row("James ","","Smith"),Row(Row("CA","Los Angles"),Row("CA","Sandiago"))),
    Row(Row("Michael ","Rose",""),Row(Row("NY","New York"),Row("NJ","Newark"))),
    Row(Row("Robert ","","Williams"),Row(Row("DE","Newark"),Row("CA","Las Vegas"))),
    Row(Row("Maria ","Anne","Jones"),Row(Row("PA","Harrisburg"),Row("CA","Sandiago"))),
    Row(Row("Jen","Mary","Brown"),Row(Row("CA","Los Angles"),Row("NJ","Newark")))
  )

  val structureSchema = new StructType()
    .add("name",new StructType()
      .add("firstname",StringType)
      .add("middlename",StringType)
      .add("lastname",StringType))
    .add("address",new StructType()
      .add("current",new StructType()
        .add("state",StringType)
        .add("city",StringType))
      .add("previous",new StructType()
        .add("state",StringType)
        .add("city",StringType)))


  val df = spark.createDataFrame(
    spark.sparkContext.parallelize(structureData),structureSchema)
  df.printSchema()
  df.show(false)

  val df2 = df.select(col("name.*"),
    col("address.current.*"),
    col("address.previous.*"))
  df2.toDF("fname","mename","lname","currAddState",
    "currAddCity","prevAddState","prevAddCity")
    .show(false)



  def flattenStructSchema(schema: StructType, prefix: String = null) : Array[Column] = {
    schema.fields.flatMap(f => {
      val columnName = if (prefix == null) f.name else (prefix + "." + f.name)

      f.dataType match {
        case st: StructType => flattenStructSchema(st, columnName)
        case _ => Array(col(columnName).as(columnName.replace(".","_")))
      }
    })
  }

  val df3 = df.select(flattenStructSchema(df.schema):_*)
  df3.printSchema()
  df3.show(false)
}

Conclusion

In this Spark article, you have learned how to flatten nested struct columns (convert structs to columns) for both simple and complex struct types. I hope you like it.

Happy Learning !!

Naveen Nelamali

Naveen Nelamali (NNK) is a Data Engineer with 20+ years of experience in transforming data into actionable insights. Over the years, he has honed his expertise in designing, implementing, and maintaining data pipelines with frameworks like Apache Spark, PySpark, Pandas, R, Hive and Machine Learning. Naveen’s journey in the field of data engineering has been one of continuous learning, innovation, and a strong commitment to data integrity. In this blog, he shares his experiences with data as he comes across them. Follow Naveen @ LinkedIn and Medium
