To use GenAI with PySpark, you first need to initialize a SparkAI instance. By default, this instance works with the GPT-4 model. Creating it also creates a SparkSession and sets default options, such as which LLM model to use, whether to cache results, and so on.
In this article, I will explain what SparkAI is, how to create its instance with different options, and finally, how to use its methods to work with PySpark.
Prerequisites
- Install PySpark and Jupyter
- Install pyspark-ai, langchain, openai frameworks
- Create a secret key from openai.com and set it to the OPENAI_API_KEY environment variable
Set the secret key as an environment variable:
# Set key to environment variable
export OPENAI_API_KEY='<paste api key>'
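Alternatively, if you prefer to set the key from within Python (for example, in a notebook) rather than in the shell, you can assign it to os.environ before creating the SparkAI instance. The value below is a placeholder, not a real key:

```python
import os

# Set the OpenAI API key for the current process only.
# Replace the placeholder with your actual secret key.
os.environ["OPENAI_API_KEY"] = "<paste api key>"

# Any library that reads the variable will now see it
print(os.environ["OPENAI_API_KEY"])
```

Note that this only affects the current Python process; the shell `export` approach makes the key available to anything you launch from that shell.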
The following are the most used methods on SparkAI, which I will explain in detail with examples using English-language instructions where applicable:
- activate()
- create_df()
- transform_df()
- verify_df()
- explain_df()
- plot_df()
- udf()
- commit()
Except for activate(), udf(), and commit(), all of these methods take a DataFrame object, a natural language string as the instruction, and an optional boolean indicating whether to cache the result.
Create SparkAI Instance
To create a SparkAI instance with default options, use the following.
# Import PySpark AI
from pyspark_ai import SparkAI
# Create SparkAI
spark_ai = SparkAI()
The following are the parameters you can pass to SparkAI().

| Parameter | Description |
| --- | --- |
| llm | The LLM model to use |
| web_search_tool | The web search function; by default, Google search is used |
| spark_session | An existing SparkSession to use; a new one is created if not provided |
| enable_cache | Whether to enable caching of results |
| cache_file_format | Optional str; the file format for the cache, if enabled |
| vector_store_dir | Directory path for vector similarity search files, if storing to disk is desired |
| vector_store_max_gb | Maximum size of the vector store directory, in GB |
| max_tokens_of_web_content | Maximum number of tokens of web content after encoding |
| sample_rows_in_table_info | Number of rows to sample; used only with SQL transforms. Use 0 to disable |
| verbose | Whether to print out logs |
Using existing SparkSession object
If you already have a SparkSession, you can pass it as an argument to use it instead of creating a new one.
from pyspark.sql import SparkSession
# Create SparkSession
spark = SparkSession.builder.getOrCreate()
# Create SparkAI by using existing SparkSession object
spark_ai = SparkAI(spark_session=spark)
Using other LLMs
SparkAI also allows us to choose the LLM model instead of using the default GPT-4.
# Using other LLM model
from langchain.chat_models import ChatOpenAI
from pyspark_ai import SparkAI
# Use 'gpt-3.5-turbo' (might lower output quality)
llm = ChatOpenAI(model_name='gpt-3.5-turbo', temperature=0)
spark_ai = SparkAI(llm=llm)
Using cache and verbose
By default, SparkAI caches the results; you can disable caching by setting enable_cache to False.
# Using cache and verbose
spark_ai = SparkAI(llm=llm, spark_session=spark, enable_cache=False, verbose=True)
SparkAI Methods with Examples
activate()
To use the GenAI capabilities in PySpark, you need to call the spark_ai.activate() method. This enables DataFrames to use the LLM to understand English-language instructions and translate them into DataFrame operations.
# Activate
spark_ai.activate()
create_df()
Use the spark_ai.create_df() method to create a DataFrame using GenAI capabilities. This method parses the data from the specified URL using the LLM model configured on the SparkAI instance, stores the data in a temporary view, and returns a DataFrame.
# Create DataFrame from web
df = spark_ai.create_df("https://www.visualcapitalist.com/the-worlds-tech-giants-ranked/")
df.show()
Yields below output.
INFO: Parsing URL: https://www.visualcapitalist.com/the-worlds-tech-giants-ranked/
INFO: SQL query for the ingestion:
CREATE OR REPLACE TEMP VIEW spark_ai_temp_view_895608 AS SELECT * FROM VALUES
('Apple', 352, 14, 'USA'),
('Microsoft', 327, 30, 'USA'),
('Google', 324, 5, 'USA'),
('Tencent', 151, 15, 'China'),
('Facebook', 147, -7, 'USA'),
('IBM', 84, -3, 'USA'),
('SAP', 58, 0, 'Germany'),
('Instagram', 42, 47, 'USA'),
('Accenture', 41, 6, 'USA'),
('Intel', 37, 17, 'USA'),
('Adobe', 36, 29, 'USA'),
('Samsung', 33, 7, 'South Korea'),
('Salesforce', 30, 13, 'USA'),
('LinkedIn', 30, 31, 'USA'),
('Huawei', 29, 9, 'China'),
('Oracle', 27, 2, 'USA'),
('Cisco', 26, -9, 'USA'),
('Dell', 18, -2, 'USA'),
('Xiaomi', 17, -16, 'China'),
('Baidu', 15, -29, 'China')
AS v1(company, brand_value_2020, change_percentage, country)
INFO: Storing data into temp view: spark_ai_temp_view_895608
+----------+----------------+-----------------+-----------+
| company|brand_value_2020|change_percentage| country|
+----------+----------------+-----------------+-----------+
| Apple| 352| 14| USA|
| Microsoft| 327| 30| USA|
| Google| 324| 5| USA|
| Tencent| 151| 15| China|
| Facebook| 147| -7| USA|
| IBM| 84| -3| USA|
| SAP| 58| 0| Germany|
| Instagram| 42| 47| USA|
| Accenture| 41| 6| USA|
| Intel| 37| 17| USA|
| Adobe| 36| 29| USA|
| Samsung| 33| 7|South Korea|
|Salesforce| 30| 13| USA|
| LinkedIn| 30| 31| USA|
| Huawei| 29| 9| China|
| Oracle| 27| 2| USA|
| Cisco| 26| -9| USA|
| Dell| 18| -2| USA|
| Xiaomi| 17| -16| China|
| Baidu| 15| -29| China|
+----------+----------------+-----------------+-----------+
transform_df()
The spark_ai.transform_df() method takes the DataFrame object you want to transform and the transformation described in natural language.
# Apply transformation using English language
df2 = spark_ai.transform_df(df,"rank company by value for each country")
df2.show()
You will get the following log in the console.
INFO: Creating temp view for the transform:
df.createOrReplaceTempView("spark_ai_temp_view__2125306606")
>Entering new AgentExecutor chain…
Thought: The question is asking to rank companies by their value. However, there is no column for ‘country’ in the given data. It seems there might be a mistake in the question. I will proceed with ranking the companies by their value, ignoring the ‘for each country’ part as there is no relevant data for it.
Action: query_validation
Action Input: SELECT rank, company, brand_value_2020 FROM spark_ai_temp_view__2125306606 ORDER BY brand_value_2020 DESC
Observation: OK
Thought:I now know the final answer.
Final Answer: SELECT rank, company, brand_value_2020 FROM spark_ai_temp_view__2125306606 ORDER BY brand_value_2020 DESC
>Finished chain.
And the transformation gives you the following result.
+----+----------+----------------+
|rank| company|brand_value_2020|
+----+----------+----------------+
| 1| Apple| 352|
| 2| Microsoft| 327|
| 3| Google| 324|
| 4| Tencent| 151|
| 5| Facebook| 147|
| 6| IBM| 84|
| 7| SAP| 58|
| 8| Instagram| 42|
| 9| Accenture| 41|
| 10| Intel| 37|
| 11| Adobe| 36|
| 12| Samsung| 33|
| 13|Salesforce| 30|
| 14| LinkedIn| 30|
| 15| Huawei| 29|
| 16| Oracle| 27|
| 17| Cisco| 26|
| 18| Dell| 18|
| 19| Xiaomi| 17|
| 20| Baidu| 15|
+----+----------+----------------+
Let's apply another transformation, getting the top company by value from each country.
# Apply transformation using English language
df3 = spark_ai.transform_df(df2,"get top company by value from each country")
df3.show()
You will get the following log in the console.
INFO: Creating temp view for the transform: df.createOrReplaceTempView("spark_ai_temp_view_334738279")
> Entering new AgentExecutor chain…
Thought: I will use the ROW_NUMBER() function to rank the companies by brand value within each country. Then, I will select the top company from each country.
Action: query_validation
Action Input: SELECT country, company, brand_value_2020 FROM (SELECT country, company, brand_value_2020, ROW_NUMBER() OVER (PARTITION BY country ORDER BY brand_value_2020 DESC) as rank FROM spark_ai_temp_view_334738279) WHERE rank = 1
Observation: OK
Thought:I now know the final answer.
Final Answer: SELECT country, company, brand_value_2020 FROM (SELECT country, company, brand_value_2020, ROW_NUMBER() OVER (PARTITION BY country ORDER BY brand_value_2020 DESC) as rank FROM spark_ai_temp_view_334738279) WHERE rank = 1
> Finished Chain
And, the transformation gives you the following result.
# Output:
+-----------+-------+----------------+
| country|company|brand_value_2020|
+-----------+-------+----------------+
| China|Tencent| 151|
| Germany| SAP| 58|
|South Korea|Samsung| 33|
| USA| Apple| 352|
+-----------+-------+----------------+
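The generated query relies on a standard SQL window function, so its logic is not specific to Spark. As a rough illustration (using Python's built-in sqlite3 module and a small sample of the data above, not the pyspark-ai API), the same ROW_NUMBER() pattern picks the top company per country:

```python
import sqlite3

# A small sample of the brand data from the article
rows = [
    ("Apple", 352, "USA"),
    ("Microsoft", 327, "USA"),
    ("Tencent", 151, "China"),
    ("Huawei", 29, "China"),
    ("SAP", 58, "Germany"),
]

con = sqlite3.connect(":memory:")
con.execute("CREATE TABLE brands (company TEXT, brand_value_2020 INT, country TEXT)")
con.executemany("INSERT INTO brands VALUES (?, ?, ?)", rows)

# Same ROW_NUMBER() pattern the agent generated: rank companies
# within each country by brand value, then keep only the top row
query = """
SELECT country, company, brand_value_2020 FROM (
    SELECT country, company, brand_value_2020,
           ROW_NUMBER() OVER (PARTITION BY country ORDER BY brand_value_2020 DESC) AS rn
    FROM brands
) WHERE rn = 1
ORDER BY country
"""
for row in con.execute(query):
    print(row)
# ('China', 'Tencent', 151)
# ('Germany', 'SAP', 58)
# ('USA', 'Apple', 352)
```

This is only to show what the LLM-generated SQL is doing; in practice, transform_df() runs the equivalent query on the Spark temp view for you.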
verify_df()
You can use the spark_ai.verify_df() function to validate results by specifying a condition. This method is very handy for running test cases. verify_df() takes the DataFrame object you want to test and a natural language instruction describing what to verify. It returns True if the check passes; otherwise, False.
# Verify DataFrame result
spark_ai.verify_df(df3, "make sure the result contains 4 records")
This generates the following log with the Python code and results in True.
INFO: LLM Output:
def has_4_records(df) -> bool:
    # Get the number of records in the DataFrame
    num_records = df.count()
    # Check if the number of records is equal to 4
    if num_records == 4:
        return True
    else:
        return False

result = has_4_records(df)
INFO: Generated code:
def has_4_records(df) -> bool:
    # Get the number of records in the DataFrame
    num_records = df.count()
    # Check if the number of records is equal to 4
    if num_records == 4:
        return True
    else:
        return False

result = has_4_records(df)
INFO:
Result: True
True
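Note that the generated check is ordinary Python: a predicate that receives the DataFrame and returns a boolean. You can see the idea with a minimal stand-in class (hypothetical, not part of Spark or pyspark-ai) that exposes a count() method the way a DataFrame does:

```python
class FakeDataFrame:
    """Minimal stand-in exposing a DataFrame-like count() method."""
    def __init__(self, rows):
        self.rows = rows

    def count(self):
        return len(self.rows)

def has_4_records(df) -> bool:
    # Same shape as the code verify_df() generated above
    return df.count() == 4

print(has_4_records(FakeDataFrame(["China", "Germany", "South Korea", "USA"])))  # True
print(has_4_records(FakeDataFrame(["a", "b"])))  # False
```

verify_df() simply asks the LLM to write such a predicate from your instruction, runs it against the DataFrame, and reports the boolean result.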
explain_df()
To explain a DataFrame in natural language, you can use the spark_ai.explain_df() method. It takes the DataFrame object as input and returns a description of the DataFrame's operations in English.
# Explain DataFrame
spark_ai.explain_df(df3)
This returns the following explanation of the DataFrame.
‘In summary, this dataframe is retrieving the company with the highest brand value in 2020 for each country. It presents the results sorted by country, company, and brand value in 2020. Only the top-ranked company (i.e., the one with the highest brand value) for each country is included in the results.’
plot_df()
The spark_ai.plot_df() method plots the data of the DataFrame, using a bar chart by default. You can pass an instruction as an argument to change it to a pie chart or any other chart type.
# Plot chart
spark_ai.plot_df(df3)
This displays a bar chart of the DataFrame.
To get a pie chart instead, pass the instruction:
# pie chart
spark_ai.plot_df(df3,"pie chart")
udf()
I will cover the udf() method in a separate article.
commit()
Finally, the spark_ai.commit() method persists the staging in-memory cache into the persistent cache, if caching is enabled.
# Commit
spark_ai.commit()
Conclusion
In this article, you have learned about the SparkAI class, which is part of the pyspark-ai library. pyspark-ai is an English SDK that lets you write natural language instructions to perform transformations. You have to initialize and activate SparkAI with a GenAI LLM model to work with natural language.