
To use GenAI with PySpark, you first need to initialize a SparkAI instance, which uses the GPT-4 model by default. Initializing it creates a SparkSession and sets default options such as the LLM model to use, whether to cache results, and so on.


In this article, I will explain what SparkAI is, how to create an instance with different options, and finally how to use its methods to work with PySpark.


Set the secret API key as an environment variable:

# Set key to environment variable
export OPENAI_API_KEY='<paste api key>'
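If you prefer, the key can also be set from Python before creating the SparkAI instance; this is equivalent to the shell export above:

```python
# Set the OpenAI API key from within Python instead of the shell
import os

os.environ["OPENAI_API_KEY"] = "<paste api key>"

# Confirm the variable is visible to the current process
print(os.environ.get("OPENAI_API_KEY"))
```

Note that environment variables set this way apply only to the current Python process and its children.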

The following are the most used methods of SparkAI, which I will explain in detail with examples using English instructions where applicable:

  • activate()
  • create_df()
  • transform_df()
  • verify_df()
  • explain_df()
  • plot_df()
  • udf()
  • commit()

Except for activate(), udf(), and commit(), all of these methods take a DataFrame object, a natural language instruction string, and an optional boolean indicating whether to cache the result.

Create SparkAI Instance

To create a SparkAI instance with default options, use the following.

# Import PySpark AI
from pyspark_ai import SparkAI

# Create SparkAI
spark_ai = SparkAI()

The following are the parameters you can pass to SparkAI().

  • llm – sets the LLM model you want to use
  • web_search_tool – the web search tool; by default, Google search is used
  • spark_session – an existing SparkSession to use; a new one is created if not provided
  • enable_cache – specifies whether to enable caching of results
  • cache_file_format – optional str, the format for the cache file if caching is enabled
  • vector_store_dir – the directory path for vector similarity search files, if storing to disk is desired
  • vector_store_max_gb – the maximum size of the vector store directory in GB
  • max_tokens_of_web_content – the maximum number of tokens of web content after encoding
  • sample_rows_in_table_info – used only with SQL transform to set the number of rows to be sampled; use 0 to disable
  • verbose – whether to print out logs

Using existing SparkSession object

If you already have a SparkSession, you can pass it as an argument to use it instead of creating a new one.

from pyspark.sql import SparkSession
# Create SparkSession
spark = SparkSession.builder.getOrCreate()

# Create SparkAI by using existing SparkSession object
spark_ai = SparkAI(spark_session=spark)

Using other LLMs

SparkAI also allows us to choose the LLM model instead of using the default GPT-4.

# Using other LLM model
from langchain.chat_models import ChatOpenAI
from pyspark_ai import SparkAI

# Use 'gpt-3.5-turbo' (might lower output quality)
llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)

spark_ai = SparkAI(llm=llm)

Using cache and verbose

By default, SparkAI caches results; you can disable caching by setting enable_cache to False.

# Using cache and verbose
spark_ai = SparkAI(llm=llm, spark_session=spark, enable_cache=False, verbose=True)

SparkAI Methods with Examples


Activate GenAI

To use the GenAI capabilities in PySpark, you need to call the spark_ai.activate() method. This enables DataFrames to use the LLM to understand English-language instructions and translate them into DataFrame operations.

# Activate
spark_ai.activate()

Create DataFrame using create_df()

Use the spark_ai.create_df() method to create a DataFrame using GenAI capabilities. This method queries the LLM with web search results: it parses the data from the specified URL using the LLM model configured on SparkAI, stores the data in a temporary view, and returns a DataFrame.

# Create DataFrame from web
df = spark_ai.create_df("")

Yields below output.

INFO: Parsing URL:

INFO: SQL query for the ingestion:
('Apple', 352, 14, 'USA'),
('Microsoft', 327, 30, 'USA'),
('Google', 324, 5, 'USA'),
('Tencent', 151, 15, 'China'),
('Facebook', 147, -7, 'USA'),
('IBM', 84, -3, 'USA'),
('SAP', 58, 0, 'Germany'),
('Instagram', 42, 47, 'USA'),
('Accenture', 41, 6, 'USA'),
('Intel', 37, 17, 'USA'),
('Adobe', 36, 29, 'USA'),
('Samsung', 33, 7, 'South Korea'),
('Salesforce', 30, 13, 'USA'),
('LinkedIn', 30, 31, 'USA'),
('Huawei', 29, 9, 'China'),
('Oracle', 27, 2, 'USA'),
('Cisco', 26, -9, 'USA'),
('Dell', 18, -2, 'USA'),
('Xiaomi', 17, -16, 'China'),
('Baidu', 15, -29, 'China')
AS v1(company, brand_value_2020, change_percentage, country)

INFO: Storing data into temp view: spark_ai_temp_view_895608

|   company|brand_value_2020|change_percentage|    country|
|     Apple|             352|               14|        USA|
| Microsoft|             327|               30|        USA|
|    Google|             324|                5|        USA|
|   Tencent|             151|               15|      China|
|  Facebook|             147|               -7|        USA|
|       IBM|              84|               -3|        USA|
|       SAP|              58|                0|    Germany|
| Instagram|              42|               47|        USA|
| Accenture|              41|                6|        USA|
|     Intel|              37|               17|        USA|
|     Adobe|              36|               29|        USA|
|   Samsung|              33|                7|South Korea|
|Salesforce|              30|               13|        USA|
|  LinkedIn|              30|               31|        USA|
|    Huawei|              29|                9|      China|
|    Oracle|              27|                2|        USA|
|     Cisco|              26|               -9|        USA|
|      Dell|              18|               -2|        USA|
|    Xiaomi|              17|              -16|      China|
|     Baidu|              15|              -29|      China|


Transform DataFrame using transform_df()

The spark_ai.transform_df() method takes the DataFrame object you want to transform and the transformation expressed in natural language.

# Apply transformation using English language
df2 = spark_ai.transform_df(df,"rank company by value for each country")

You will get the following log in the console.

INFO: Creating temp view for the transform:

>Entering new AgentExecutor chain…

>Finished chain.

And the transformation gives you the following result.

|rank|   company|brand_value_2020|
|   1|     Apple|             352|
|   2| Microsoft|             327|
|   3|    Google|             324|
|   4|   Tencent|             151|
|   5|  Facebook|             147|
|   6|       IBM|              84|
|   7|       SAP|              58|
|   8| Instagram|              42|
|   9| Accenture|              41|
|  10|     Intel|              37|
|  11|     Adobe|              36|
|  12|   Samsung|              33|
|  13|Salesforce|              30|
|  14|  LinkedIn|              30|
|  15|    Huawei|              29|
|  16|    Oracle|              27|
|  17|     Cisco|              26|
|  18|      Dell|              18|
|  19|    Xiaomi|              17|
|  20|     Baidu|              15|
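For reference, the generated SQL for a transform like this typically relies on a window-function rank partitioned by country. The following plain-Python sketch reproduces that per-country ranking logic on a small sample of the data; it is only an illustration of the logic, not the code the LLM actually generates:

```python
# Rank companies by brand value within each country (plain-Python sketch)
from collections import defaultdict

rows = [
    ("Apple", 352, "USA"), ("Microsoft", 327, "USA"),
    ("Tencent", 151, "China"), ("Huawei", 29, "China"),
    ("SAP", 58, "Germany"),
]

# Group rows by country
by_country = defaultdict(list)
for company, value, country in rows:
    by_country[country].append((company, value))

# Within each country, sort descending by brand value and assign 1-based ranks
ranked = []
for country, companies in by_country.items():
    for rank, (company, value) in enumerate(
        sorted(companies, key=lambda cv: -cv[1]), start=1
    ):
        ranked.append((country, rank, company, value))

print(ranked)
```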

Let’s apply another transformation to get the top company by value from each country.

# Apply transformation using English language
df3 = spark_ai.transform_df(df2,"get top company by value from each country")

You will get the following log in the console.

INFO: Creating temp view for the transform: df.createOrReplaceTempView("spark_ai_temp_view_334738279")

> Entering new AgentExecutor chain…

> Finished chain.

And, the transformation gives you the following result.

# Output:
|    country|company|brand_value_2020|
|      China|Tencent|             151|
|    Germany|    SAP|              58|
|South Korea|Samsung|              33|
|        USA|  Apple|             352|
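This "top company per country" transform boils down to keeping only the rank-1 row within each country. Here is a minimal plain-Python sketch of that grouping logic on a small sample (an illustration only, not the generated SQL):

```python
# Pick the company with the highest brand value in each country
rows = [
    ("Apple", 352, "USA"), ("Microsoft", 327, "USA"),
    ("Tencent", 151, "China"), ("Huawei", 29, "China"),
    ("SAP", 58, "Germany"), ("Samsung", 33, "South Korea"),
]

top = {}
for company, value, country in rows:
    # Keep the row only if it beats the current best for this country
    if country not in top or value > top[country][1]:
        top[country] = (company, value)

print(top)
```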


Verify DataFrame using verify_df()

You can use the spark_ai.verify_df() method to validate results by specifying a condition. This method is very helpful and handy for running test cases. verify_df() takes the DataFrame object you want to test and a natural language instruction describing the check to perform. It returns True if the check passes; otherwise, False.

# Verify DataFrame result
spark_ai.verify_df(df3, "make sure the result contains 4 records")

This generates the following log with the Python code and results in True.

INFO: LLM Output:
def has_4_records(df) -> bool:
    # Get the number of records in the DataFrame
    num_records = df.count()

    # Check if the number of records is equal to 4
    if num_records == 4:
        return True
    return False

result = has_4_records(df)
INFO: Generated code:
def has_4_records(df) -> bool:
    # Get the number of records in the DataFrame
    num_records = df.count()

    # Check if the number of records is equal to 4
    if num_records == 4:
        return True
    return False

result = has_4_records(df)

Result: True
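The generated check above is just an ordinary Python function over the DataFrame's row count. A standalone, generalized sketch of the same idea is shown below; the stub class with a count() method stands in for a real Spark DataFrame so the logic can be run without Spark:

```python
def has_n_records(df, n: int) -> bool:
    # Compare the DataFrame's row count against the expected value
    return df.count() == n

# Stub standing in for a Spark DataFrame in this illustration
class FakeDF:
    def __init__(self, rows):
        self.rows = rows

    def count(self):
        return len(self.rows)

print(has_n_records(FakeDF([1, 2, 3, 4]), 4))  # True
```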


Explain DataFrame using explain_df()

To explain a DataFrame in natural language, you can use spark_ai.explain_df(). This method takes the DataFrame object as input and returns a description of the DataFrame's operations in English.

# Explain DataFrame
spark_ai.explain_df(df3)

This returns the following explanation of the DataFrame.

‘In summary, this dataframe is retrieving the company with the highest brand value in 2020 for each country. It presents the results sorted by country, company, and brand value in 2020. Only the top-ranked company (i.e., the one with the highest brand value) for each country is included in the results.’


Plot DataFrame using plot_df()

The spark_ai.plot_df() method plots the data of the DataFrame, in a bar chart by default. You can pass an instruction as an argument to change the chart to a pie chart or any other chart type.

# Plot chart
spark_ai.plot_df(df3)

This displays a bar chart of the DataFrame.

To get a pie chart instead:

# pie chart
spark_ai.plot_df(df3,"pie chart")


Create UDF using udf()

I will cover the udf() method in a separate article.


Commit the Cache using commit()

Finally, spark_ai.commit() is used to persist the staging in-memory cache into the persistent cache, if caching is enabled.

# Commit cache
spark_ai.commit()


Conclusion

In this article, you learned about the SparkAI class, which is part of the pyspark-ai library. pyspark-ai is an English SDK that lets you express transformations as natural language instructions. You have to initialize and activate SparkAI with a GenAI LLM model before working with natural language.