How do you select all the other columns when using groupBy() on a Spark DataFrame? In Spark Scala, groupBy() followed by agg() returns only the grouping columns and the aggregated values, so there is no direct way to group a DataFrame by one column and carry all the other columns into the groupBy output.


In this article, we discuss groupBy() on a DataFrame in detail, and how to join its result with the original DataFrame to get all columns in the output. This is a two-step process:

  • First, group the DataFrame on a column, which produces the aggregated DataFrame.
  • Second, join this aggregated DataFrame with the original DataFrame to get all the columns in the result.

1. Spark groupBy() on DataFrame

In Spark Scala, grouping a DataFrame can be accomplished using the groupBy() method of a DataFrame. This method groups the rows of the DataFrame based on one or more columns and returns a RelationalGroupedDataset object, which can be used to perform various aggregation operations.

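// Imports
import org.apache.spark.sql.functions._
// Needed for toDF(); assumes a SparkSession named spark is in scope (as in spark-shell)
import spark.implicits._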

// Create DataFrame using toDF and Seq of Data
val df = Seq(
  ("Alice", "Math", 90),
  ("Alice", "Science", 80),
  ("Bob", "Math", 85),
  ("Bob", "Science", 95),
  ("Charlie", "Math", 70),
  ("Charlie", "Science", 75)
).toDF("name", "subject", "score")

// Group by name and compute the aggregates with agg()
val groupedDf = df.groupBy("name").agg(avg("score"), max("score"), count("subject"))

In this example, df is the input DataFrame, which contains three columns: “name”, “subject”, and “score”. The groupBy() method is called on the DataFrame to group the rows by the “name” column. The agg() method is then called on the resulting RelationalGroupedDataset to perform aggregation operations on the grouped data. In this case, we compute the average and maximum of the “score” column and the count of the “subject” column for each group.

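The resulting groupedDf DataFrame contains the following rows (row order is not guaranteed):

// Output:
+-------+----------+----------+--------------+
|   name|avg(score)|max(score)|count(subject)|
+-------+----------+----------+--------------+
|  Alice|      85.0|        90|             2|
|    Bob|      90.0|        95|             2|
|Charlie|      72.5|        75|             2|
+-------+----------+----------+--------------+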

Note that we use the avg(), max(), and count() functions from the org.apache.spark.sql.functions package to perform the aggregation operations. The agg() method accepts one or more aggregate Column expressions as its arguments.
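
The default column names such as avg(score) are awkward to reference later, for example after a join. The sketch below is not part of the original example; it shows one way to name each aggregate with as(). The names avg_score, max_score, and subject_count are illustrative choices, and the local SparkSession is an assumption made so the snippet can run on its own (for example, pasted into spark-shell).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, count, max}

// Assumption: a local session, used here only for illustration
val spark = SparkSession.builder()
  .appName("AliasedAggExample")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("Alice", "Math", 90),
  ("Alice", "Science", 80),
  ("Bob", "Math", 85),
  ("Bob", "Science", 95),
  ("Charlie", "Math", 70),
  ("Charlie", "Science", 75)
).toDF("name", "subject", "score")

// Same aggregations as above, each given an explicit (illustrative) name
val aliasedDf = df.groupBy("name").agg(
  avg("score").as("avg_score"),
  max("score").as("max_score"),
  count("subject").as("subject_count")
)

aliasedDf.printSchema()
```

With explicit names, later steps can refer to a column as col("avg_score") instead of quoting the generated name.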

2. Select all columns after GroupBy

To select all columns in a Spark DataFrame when using the groupBy() function, use the grouped DataFrame and join it with the base DataFrame. This gives you all the columns from the base DataFrame along with the grouped results.

Here’s an example:

// Get all columns
val groupedDf_all_columns = groupedDf.join(df, Seq("name"), "inner")

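The resulting groupedDf_all_columns DataFrame contains the following rows:

+-------+----------+----------+--------------+-------+-----+
|   name|avg(score)|max(score)|count(subject)|subject|score|
+-------+----------+----------+--------------+-------+-----+
|  Alice|      85.0|        90|             2|   Math|   90|
|  Alice|      85.0|        90|             2|Science|   80|
|    Bob|      90.0|        95|             2|   Math|   85|
|    Bob|      90.0|        95|             2|Science|   95|
|Charlie|      72.5|        75|             2|   Math|   70|
|Charlie|      72.5|        75|             2|Science|   75|
+-------+----------+----------+--------------+-------+-----+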

Here, we join the base DataFrame df with the groupedDf DataFrame on the “name” column. As you can see, all columns from the original DataFrame are included along with the columns of the groupBy output. Each original row keeps its own “subject” and “score” values, while the aggregated values are repeated on every row of the same group.

3. Conclusion

In Spark, you can keep every column of a DataFrame alongside the groupBy results by combining the groupBy(), agg(), and join() methods. Pass the grouping column (or columns) to groupBy(), then use agg() to apply aggregation functions to the other columns.

To include all columns of the original DataFrame in the groupBy output, use the join() method to join the grouped DataFrame with the base DataFrame on the grouping column.

Here’s a summary of the steps:

  1. Use the groupBy() method to group the DataFrame by the desired column or columns.
  2. Use the agg() method to apply aggregation functions to the other columns.
  3. Use the join() method to join the grouped DataFrame with the base DataFrame, so that all columns of the original DataFrame appear in the output.
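
The steps above can be wrapped in a small reusable function. This is a minimal sketch, not a Spark API: withGroupAggregates is a hypothetical helper name, and the aliases avg_score and max_score are illustrative. It assumes the aggregate column names do not clash with existing column names, and it is written as statements that can be pasted into spark-shell.

```scala
import org.apache.spark.sql.{Column, DataFrame, SparkSession}
import org.apache.spark.sql.functions.{avg, col, max}

// Hypothetical helper (not part of Spark): aggregate per group,
// then join the aggregates back onto every original row.
def withGroupAggregates(df: DataFrame, keys: Seq[String], aggs: Seq[Column]): DataFrame = {
  require(keys.nonEmpty && aggs.nonEmpty, "need at least one key and one aggregate")
  // Steps 1 and 2: group by the key columns and aggregate
  val grouped = df.groupBy(keys.map(col): _*).agg(aggs.head, aggs.tail: _*)
  // Step 3: join on the key columns; original columns come first in the result
  df.join(grouped, keys, "inner")
}

// Assumption: a local session, used here only for illustration
val spark = SparkSession.builder()
  .appName("GroupByAllColumns")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

val scores = Seq(
  ("Alice", "Math", 90),
  ("Alice", "Science", 80),
  ("Bob", "Math", 85),
  ("Bob", "Science", 95),
  ("Charlie", "Math", 70),
  ("Charlie", "Science", 75)
).toDF("name", "subject", "score")

val result = withGroupAggregates(
  scores,
  Seq("name"),
  Seq(avg("score").as("avg_score"), max("score").as("max_score"))
)

result.orderBy("name", "subject").show()
```

One caveat with this design: an equality join does not match null keys, so rows whose grouping column is null would be dropped by the inner join even though groupBy() places them in a group of their own.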

Keep in mind that the join returns one row for every row of the original DataFrame, with the aggregated values repeated on each row of a group, so the output can be memory-intensive when the DataFrame has many rows or many columns. Therefore, it is generally good practice to select only the necessary columns before joining.
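
As a hedged illustration of that advice, the sketch below trims the base DataFrame to the columns that are actually needed before joining. The choice of “name” and “score” as the needed columns, the alias avg_score, and the local SparkSession are assumptions made for the example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.avg

// Assumption: a local session, used here only for illustration
val spark = SparkSession.builder()
  .appName("SelectBeforeJoin")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("Alice", "Math", 90),
  ("Alice", "Science", 80),
  ("Bob", "Math", 85),
  ("Bob", "Science", 95),
  ("Charlie", "Math", 70),
  ("Charlie", "Science", 75)
).toDF("name", "subject", "score")

// Aggregate once per name
val avgDf = df.groupBy("name").agg(avg("score").as("avg_score"))

// Keep only the columns needed downstream (here: name and score)
val slimDf = df.select("name", "score")

// Join the aggregates onto the trimmed DataFrame
val slimResult = avgDf.join(slimDf, Seq("name"), "inner")

slimResult.show()
```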
