  • Post last modified:December 31, 2024

In Polars, the pivot() method reshapes a DataFrame from long form to wide form: the unique values of one or more columns become new columns, and the cells are filled from a values column, aggregated when several rows fall into the same cell. This makes it a good fit for converting long-form data into wide-form data while summarizing values.


In this article, I will explain the Polars DataFrame pivot() method by using its syntax, parameters, and usage to demonstrate how it returns a new DataFrame with the pivoted data.

Key Points –

  • The pivot() function in Polars is used to reshape data, turning unique values from a specified column into new columns, with aggregation of values performed as needed.
  • The index parameter specifies the column(s) whose unique values will be used as the row index in the pivoted DataFrame.
  • The on parameter defines the column(s) whose unique values will become the new columns in the pivoted DataFrame.
  • The values parameter defines the column(s) whose data will be aggregated and placed into the pivoted table.
  • Built-in aggregations are selected by name through the aggregate_function parameter (e.g., "sum", "mean", "max", "min") and are applied to the values during the pivot.
  • A single aggregate_function is applied to every values column; when several values columns are given, each one produces its own set of pivoted columns.
  • Custom aggregations can be passed to aggregate_function as a Polars expression built on pl.element(), allowing more flexibility in summarizing data.
  • The method requires an aggregation function when multiple rows have the same combination of index and on values.

Syntax of Polars DataFrame.pivot()

The syntax of the Polars DataFrame pivot() method is as follows.


# Syntax of pivot()
DataFrame.pivot(
    on: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
    *,
    index: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None = None,
    values: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None = None,
    aggregate_function: PivotAgg | Expr | None = None,
    maintain_order: bool = True,
    sort_columns: bool = False,
    separator: str = '_',
) → DataFrame

Parameters of the Polars DataFrame.pivot()

Following are the parameters of the Polars DataFrame.pivot() method.

  • on – The column(s) whose unique values will be used as the new columns in the pivoted DataFrame. Can be a single column name or a list of column names.
  • index – The column(s) that are carried over from the input and identify the rows of the pivoted DataFrame: the result has one row for each unique combination of their values. Can be a single column name or a list of column names. If None, all remaining columns not given in on and values are used. At least one of index and values must be specified.
  • values – The column(s) whose values fill the cells under the new columns, aggregated if needed. It can be a single column name or a list of column names. If None, all remaining columns not given in on and index are used.
  • aggregate_function – The aggregation to apply when multiple rows share the same combination of index and on values. Can be the name of a built-in aggregation ("min", "max", "first", "last", "sum", "mean", "median", "len") or a custom Expr. If None, no aggregation is applied, and each index/on combination must occur at most once.
  • maintain_order – Boolean flag that determines whether to maintain the original order of the rows in the pivoted DataFrame. Default is True.
  • sort_columns – Boolean flag that determines whether to sort the new columns alphabetically. Default is False.
  • separator – String placed between the values column name and the pivoted value in the generated column names when there are multiple values columns (for example, Fee_Java). Default is '_'.

Return Value

This function returns a new DataFrame with the pivoted data.

Usage of Polars DataFrame.pivot()

The pivot() method reshapes a DataFrame by transforming unique values from one or more columns into separate columns, and optionally aggregates data based on a specified index and columns.

Now, let’s create a Polars DataFrame.


import polars as pl

# Create a DataFrame
df = pl.DataFrame({'Student Names': ['Jenny', 'Singh', 'Charles', 'Richard', 'Veena'],
                   'Category': ['Online', 'Offline', 'Offline', 'Offline', 'Online'],
                   'Gender': ['Female', 'Male', 'Male', 'Male', 'Female'],
                   'Courses': ['Java', 'Spark', 'PySpark', 'Hadoop', 'C'],
                   'Fee': [15000, 17000, 27000, 29000, 12000],
                   'Discount': [1100, 800, 1000, 1600, 600]})
print("Original DataFrame:\n", df)

This prints the original DataFrame, which has 5 rows and 6 columns.

To create a basic pivot table using the Gender and Courses columns in the Polars DataFrame, we can specify the index as Gender and the on column as Courses.


# Pivot by Gender and Courses
df2 = df.pivot("Courses", index="Gender", values="Fee")
print(df2)

# Equivalent call, passing on as a keyword argument
df2 = df.pivot(on="Courses", index="Gender", values="Fee")
print(df2)

Here,

  • on="Courses": This will pivot the DataFrame based on the unique values in the Courses column.
  • index="Gender": This will use the Gender column as the index (rows).
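  • values="Fee": This specifies that the Fee column values fill the cells of the pivoted table. No aggregation is needed here because each Gender and Courses combination occurs only once.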
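
The result has one row per Gender and one column per course (Java, Spark, PySpark, Hadoop, C). Combinations that do not occur in the data are null.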

Pivot with Multiple Indexes

To pivot on multiple index columns, pass a list of column names to the index parameter. The result has one row for each unique combination of those columns, with the values spread across the pivoted columns.


# Pivot the DataFrame with multiple indexes ('Gender' and 'Category')
df2 = df.pivot(
    on="Courses",  
    index=["Gender", "Category"],  
    values="Fee", 
    aggregate_function="sum"  
)
print(df2)

# Output:
# shape: (2, 7)
┌────────┬──────────┬───────┬───────┬─────────┬────────┬───────┐
│ Gender ┆ Category ┆ Java  ┆ Spark ┆ PySpark ┆ Hadoop ┆ C     │
│ ---    ┆ ---      ┆ ---   ┆ ---   ┆ ---     ┆ ---    ┆ ---   │
│ str    ┆ str      ┆ i64   ┆ i64   ┆ i64     ┆ i64    ┆ i64   │
╞════════╪══════════╪═══════╪═══════╪═════════╪════════╪═══════╡
│ Female ┆ Online   ┆ 15000 ┆ null  ┆ null    ┆ null   ┆ 12000 │
│ Male   ┆ Offline  ┆ null  ┆ 17000 ┆ 27000   ┆ 29000  ┆ null  │
└────────┴──────────┴───────┴───────┴─────────┴────────┴───────┘

Here,

  • on="Courses": The DataFrame is pivoted based on the Courses column.
  • index=["Gender", "Category"]: Multiple columns, Gender and Category, identify the rows. Polars has no multi-level index, so both remain ordinary columns in the result.
  • values="Fee": The Fee column values are aggregated.
  • aggregate_function="sum": The sum is used to aggregate the Fee values for each combination of Gender and Category and each unique course. In this data every combination occurs once, so the sum returns the original fee.

Pivot Without Aggregation

The pivot() method needs an aggregation function only when multiple rows correspond to the same combination of index and on. If every combination occurs at most once, you can leave aggregate_function as None (the default), and the values are simply moved into a wide-format DataFrame.


# Convert to wide format without aggregation
result = df.pivot(
    values="Fee",
    index="Category",
    on="Courses",
    aggregate_function=None  
)
print(result)

# Output:
# shape: (2, 6)
┌──────────┬───────┬───────┬─────────┬────────┬───────┐
│ Category ┆ Java  ┆ Spark ┆ PySpark ┆ Hadoop ┆ C     │
│ ---      ┆ ---   ┆ ---   ┆ ---     ┆ ---    ┆ ---   │
│ str      ┆ i64   ┆ i64   ┆ i64     ┆ i64    ┆ i64   │
╞══════════╪═══════╪═══════╪═════════╪════════╪═══════╡
│ Online   ┆ 15000 ┆ null  ┆ null    ┆ null   ┆ 12000 │
│ Offline  ┆ null  ┆ 17000 ┆ 27000   ┆ 29000  ┆ null  │
└──────────┴───────┴───────┴─────────┴────────┴───────┘

Here,

  • values="Fee": The column whose values are used in the pivot.
  • index="Category": Specifies the column for the row index.
  • on="Courses": Columns are created based on unique Courses values.
  • aggregate_function=None: Indicates no aggregation is performed. The function will raise an error if multiple rows exist for a combination.

Custom Aggregation Function using polars.element()

You can pass a custom aggregation function using polars.element() when working with the pivot() function. This allows you to apply any element-wise function, such as tanh(), sin(), log(), or any other function, before applying an aggregation (like mean, sum, etc.).

In this case, we will use a custom aggregation function, specifically applying the tanh() transformation to the values first and then aggregating them (e.g., by mean()).


# Pivot with a custom aggregation function 
# Using tanh transformation
result = df.pivot(
    values=["Fee", "Discount"],
    index="Gender",
    on="Courses",
    aggregate_function=pl.element().tanh().mean()
)
print(result)

# Output:
# shape: (2, 11)
┌────────┬──────────┬───────────┬────────────┬───┬────────────┬────────────┬───────────┬───────────┐
│ Gender ┆ Fee_Java ┆ Fee_Spark ┆ Fee_PySpar ┆ … ┆ Discount_S ┆ Discount_P ┆ Discount_ ┆ Discount_ │
│ ---    ┆ ---      ┆ ---       ┆ k          ┆   ┆ park       ┆ ySpark     ┆ Hadoop    ┆ C         │
│ str    ┆ f64      ┆ f64       ┆ ---        ┆   ┆ ---        ┆ ---        ┆ ---       ┆ ---       │
│        ┆          ┆           ┆ f64        ┆   ┆ f64        ┆ f64        ┆ f64       ┆ f64       │
╞════════╪══════════╪═══════════╪════════════╪═══╪════════════╪════════════╪═══════════╪═══════════╡
│ Female ┆ 1.0      ┆ null      ┆ null       ┆ … ┆ null       ┆ null       ┆ null      ┆ 1.0       │
│ Male   ┆ null     ┆ 1.0       ┆ 1.0        ┆ … ┆ 1.0        ┆ 1.0        ┆ 1.0       ┆ null      │
└────────┴──────────┴───────────┴────────────┴───┴────────────┴────────────┴───────────┴───────────┘

Here,

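  • pl.element().tanh(): This applies the tanh() transformation element-wise to the values of the columns.
  • mean(): After applying the tanh() transformation, we then compute the mean of the transformed values for each Gender and Courses combination.
  • Because tanh() saturates at 1 for large inputs, every fee and discount in this data maps to 1.0, which is why all non-null cells in the output are 1.0.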

Handling Missing Values in Polars pivot()

When performing a pivot operation in Polars, missing values may arise if certain combinations of index and on columns do not exist in the original DataFrame. Polars marks these cells as null (not NaN), so they are replaced with fill_null() rather than fill_nan().


# Pivot the table
result = df.pivot(
    values="Fee",
    index="Gender",
    on="Courses",
    aggregate_function="mean"  
)

# Replace missing (null) values with 0
result_filled = result.fill_null(0)
print("Pivot Result (Missing Values Replaced):\n", result_filled)

# Output:
# Pivot Result (Missing Values Replaced):
# shape: (2, 6)
┌────────┬─────────┬─────────┬─────────┬─────────┬─────────┐
│ Gender ┆ Java    ┆ Spark   ┆ PySpark ┆ Hadoop  ┆ C       │
│ ---    ┆ ---     ┆ ---     ┆ ---     ┆ ---     ┆ ---     │
│ str    ┆ f64     ┆ f64     ┆ f64     ┆ f64     ┆ f64     │
╞════════╪═════════╪═════════╪═════════╪═════════╪═════════╡
│ Female ┆ 15000.0 ┆ 0.0     ┆ 0.0     ┆ 0.0     ┆ 12000.0 │
│ Male   ┆ 0.0     ┆ 17000.0 ┆ 27000.0 ┆ 29000.0 ┆ 0.0     │
└────────┴─────────┴─────────┴─────────┴─────────┴─────────┘

Here,

  • fill_null(0): This replaces all null values in the resulting DataFrame with 0. fill_nan(0) would leave them untouched, because it only replaces floating-point NaN values.

Conclusion

In this article, I have explained the Polars DataFrame pivot() method by using syntax, parameters, and usage. I also demonstrated how to create a spreadsheet-style pivot table as a DataFrame with detailed, well-explained examples.

Happy Learning!!
