Spark – How to slice an array and get a subset of elements

Spark SQL provides a slice() function that returns a subset or range of elements (a subarray) from an array column of a DataFrame. It is part of the Spark SQL array functions group. In this article, I will explain the syntax of the slice() function and its usage with a Scala example.

To use slice() on a Spark DataFrame or Dataset, import it from org.apache.spark.sql.functions (org.apache.spark.sql.functions.slice).

Though I’ve used a Scala example here, the same approach also works with PySpark (Spark with Python).

1. Slice() function syntax

Once org.apache.spark.sql.functions.slice is imported, the function has the following syntax.


// Slice() function syntax
slice(x : org.apache.spark.sql.Column, start : scala.Int, length : scala.Int) : org.apache.spark.sql.Column 

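slice() takes a Column of ArrayType as its first argument, followed by the start position and the number of elements to extract from the array. Positions are 1-based, so a start of 1 refers to the first element; a negative start counts back from the end of the array.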
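
To make the indexing concrete, here is a minimal sketch of how start and length behave. The one-row DataFrame, the column name letters, and the object name SliceIndexingSketch are made up for this illustration and are not part of the article's example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, slice}

object SliceIndexingSketch {

  // Returns (firstTwo, lastTwo, fromFourth) for one five-element array.
  def run(spark: SparkSession): (Seq[String], Seq[String], Seq[String]) = {
    import spark.implicits._

    // Hypothetical one-row DataFrame, used only for this illustration
    val letters = Seq(Seq("a", "b", "c", "d", "e")).toDF("letters")

    val row = letters.select(
      slice(col("letters"), 1, 2).as("firstTwo"),   // start = 1 is the first element
      slice(col("letters"), -2, 2).as("lastTwo"),   // negative start counts from the end
      slice(col("letters"), 4, 10).as("fromFourth") // length may run past the end of the array
    ).head()

    (row.getSeq[String](0), row.getSeq[String](1), row.getSeq[String](2))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SliceIndexingSketch")
      .master("local[1]")
      .getOrCreate()
    println(run(spark))
    spark.stop()
  }
}
```

Here firstTwo holds a and b, while lastTwo and fromFourth both hold d and e, because a length that runs past the end of the array simply stops at the last element.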

Like all Spark SQL functions, slice() returns an org.apache.spark.sql.Column; here the column is of ArrayType, with the same element type as the input array.
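
Because the result is still an array column, it can be passed straight to other array functions. The sketch below checks the result type and then chains size() and array_join() on the sliced column. The DataFrame, the column names langs and subset, and the object name are assumptions made for this illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_join, col, size, slice}
import org.apache.spark.sql.types.{ArrayType, StringType}

object SliceResultTypeSketch {

  // Returns (is the sliced column an array of strings?, its size, its joined text).
  def run(spark: SparkSession): (Boolean, Int, String) = {
    import spark.implicits._

    // Hypothetical one-row DataFrame, used only for this illustration
    val df = Seq(Seq("Java", "Scala", "C++", "Pascal", "Spark")).toDF("langs")

    // 3 elements starting at position 2
    val sliced = df.select(slice(col("langs"), 2, 3).as("subset"))

    // Inspect the schema of the new column
    val isStringArray = sliced.schema("subset").dataType match {
      case ArrayType(StringType, _) => true
      case _                        => false
    }

    // Chain other array functions on the sliced column
    val row = sliced.select(
      size(col("subset")).as("n"),
      array_join(col("subset"), "|").as("joined")
    ).head()

    (isStringArray, row.getInt(0), row.getString(1))
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SliceResultTypeSketch")
      .master("local[1]")
      .getOrCreate()
    println(run(spark))
    spark.stop()
  }
}
```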

Before we use slice() to get a subset or range of elements, let’s first create a DataFrame.


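  // Imports needed by this snippet
  import org.apache.spark.sql.{Row, SparkSession}
  import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

  val spark = SparkSession.builder()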
    .appName("SparkByExamples.com")
    .master("local")
    .getOrCreate()

  val arrayStructureData = Seq(
    Row("James,,Smith",List("Java","Scala","C++","Pascal","Spark")),
    Row("Michael,Rose,",List("Spark","Java","C++","Scala","PHP")),
    Row("Robert,,Williams",List("CSharp","VB",".Net","C#.net",""))
  )

  val arrayStructureSchema = new StructType()
    .add("name",StringType)
    .add("languagesAtSchool", ArrayType(StringType))

  val df = spark.createDataFrame(
    spark.sparkContext.parallelize(arrayStructureData),arrayStructureSchema)
  df.show(false)
  df.printSchema()

This yields the output below.


// Output:
+----------------+---------------------------------+
|name            |languagesAtSchool                |
+----------------+---------------------------------+
|James,,Smith    |[Java, Scala, C++, Pascal, Spark]|
|Michael,Rose,   |[Spark, Java, C++, Scala, PHP]   |
|Robert,,Williams|[CSharp, VB, .Net, C#.net, ]     |
+----------------+---------------------------------+

root
 |-- name: string (nullable = true)
 |-- languagesAtSchool: array (nullable = true)
 |    |-- element: string (containsNull = true)

2. Slice() function usage

Now, let’s use the slice() SQL function to slice the array and get a subset of elements from an array column. The call below starts at position 2 (the second element) and takes 3 elements.


// Slice() function usage
  import org.apache.spark.sql.functions.{col, slice}

  val sliceDF = df.withColumn("languages",
    slice(col("languagesAtSchool"),2,3))
    .drop("languagesAtSchool")
  sliceDF.printSchema()
  sliceDF.show(false)

This yields the output below.


// Output:

+----------------+--------------------+
|name            |languages           |
+----------------+--------------------+
|James,,Smith    |[Scala, C++, Pascal]|
|Michael,Rose,   |[Java, C++, Scala]  |
|Robert,,Williams|[VB, .Net, C#.net]  |
+----------------+--------------------+

3. Using slice() on Spark SQL expression

Since Spark can also execute raw SQL, let’s write the same slicing example as a Spark SQL expression. To use raw SQL, first register the DataFrame as a temporary view with createOrReplaceTempView(). The view is available for the lifetime of the current SparkSession.


  df.createOrReplaceTempView("PERSON")
  spark.sql("select name, slice(languagesAtSchool,2,3) as NameArray from PERSON")
    .show(false)

This yields the same rows as the example above; the only difference is that the array column is named NameArray instead of languages.
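
If registering a view is not needed, the same SQL expression can also be passed to selectExpr() on the DataFrame. The following is a minimal sketch under that assumption; the shortened sample rows and the object name SliceSelectExprSketch are made up for this illustration.

```scala
import org.apache.spark.sql.SparkSession

object SliceSelectExprSketch {

  // Returns (name, sliced languages) pairs computed through a SQL expression string.
  def run(spark: SparkSession): Seq[(String, Seq[String])] = {
    import spark.implicits._

    // Hypothetical sample rows, shortened from the article's data
    val df = Seq(
      ("James", Seq("Java", "Scala", "C++", "Pascal", "Spark")),
      ("Michael", Seq("Spark", "Java", "C++", "Scala", "PHP"))
    ).toDF("name", "languagesAtSchool")

    // Same slice(...) expression as in the raw SQL query, without a temporary view
    df.selectExpr("name", "slice(languagesAtSchool, 2, 3) AS languages")
      .as[(String, Seq[String])]
      .collect()
      .toSeq
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("SliceSelectExprSketch")
      .master("local[1]")
      .getOrCreate()
    run(spark).foreach(println)
    spark.stop()
  }
}
```

selectExpr() keeps everything in the DataFrame API, which can be convenient when the slice is a single step in a longer chain of transformations; the temporary view remains the better fit when the whole query is written in SQL.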

4. Complete example

Below is a complete example of getting a subset of the array elements.


package com.sparkbyexamples.spark.dataframe.functions.collection

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.functions.{col, slice}
import org.apache.spark.sql.types.{ArrayType, StringType, StructType}

object SliceArray extends App {

  val spark = SparkSession.builder()
    .appName("SparkByExamples.com")
    .master("local")
    .getOrCreate()

  val arrayStructureData = Seq(
    Row("James,,Smith",List("Java","Scala","C++","Pascal","Spark")),
    Row("Michael,Rose,",List("Spark","Java","C++","Scala","PHP")),
    Row("Robert,,Williams",List("CSharp","VB",".Net","C#.net",""))
  )

  val arrayStructureSchema = new StructType()
    .add("name",StringType)
    .add("languagesAtSchool", ArrayType(StringType))

  val df = spark.createDataFrame(
    spark.sparkContext.parallelize(arrayStructureData),arrayStructureSchema)
  df.show(false)
  df.printSchema()


  // Take 3 elements starting at position 2 (positions are 1-based)
  val sliceDF = df.withColumn("languages",
    slice(col("languagesAtSchool"),2,3))
    .drop("languagesAtSchool")
  sliceDF.printSchema()
  sliceDF.show(false)

  df.createOrReplaceTempView("PERSON")
  spark.sql("select name, slice(languagesAtSchool,2,3) as NameArray from PERSON")
    .show(false)
}

This example is also available in the spark-scala-examples GitHub project for reference.

Conclusion

In this article, you have learned how to use the slice() function to get a subset or range of elements from an array column of a DataFrame or Dataset, and how to use the same function in a Spark SQL expression.

Happy Learning !!