Python and Apache Spark integration
This page describes how to develop a simple ingest and search application with the Aerospike Vector Search (AVS) Python client. The client integrates AVS with Apache Spark to support large-scale data ingestion and search tasks.
Prerequisites
- A Spark cluster with PySpark and Python version 3.9 or later.
- A running AVS cluster that is reachable from the Spark master and worker nodes.
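Before submitting a job, you can verify that the AVS cluster is reachable with a short connection check like the following sketch. The host `avs-host` and port `5000` are placeholder values; substitute your own.

```python
# Minimal reachability check for the AVS cluster.
# "avs-host" and 5000 are placeholders; replace them with your seed host and port.
from aerospike_vector_search import types, Client

client = Client(seeds=types.HostPort(host="avs-host", port=5000), is_loadbalancer=True)
print("Connected to AVS")
client.close()
```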
Develop an ingest and search application
- Add the following snippet to the Spark cluster's initialization script. The script installs the AVS Python client and ensures that the client is available on all nodes.

  ```bash
  #!/bin/bash
  # Initialization action to install a Python package on all nodes of a Google Cloud Dataproc cluster
  python3 -m pip install aerospike_vector_search
  ```

- The script runs when the Spark cluster is created.
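Once the cluster is up, you can optionally confirm that the package installed correctly on every worker. The following sketch is a hypothetical check, not part of the sample application: it runs a small Spark job that imports the client on each partition, so any worker missing the package fails with an `ImportError`.

```python
# Sanity check: import the AVS client package on each worker partition.
# Illustrative sketch only; the app name and partition count are arbitrary.
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("avs-install-check").getOrCreate()

def import_check(_partition):
    # Raises ImportError on any worker where the package is missing
    import aerospike_vector_search
    yield "ok"

print(spark.sparkContext.parallelize(range(4), 4).mapPartitions(import_check).collect())
spark.stop()
```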
Sample application
You can copy the following example to build the ingest and search application. The example outlines a framework for integrating the AVS Python client into large-scale data processing systems.
```python
import argparse

from aerospike_vector_search import Index, types, Client
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType

# Usage:
# spark-submit test.py --host "34.41.1.43" --port 5000 --read_path "gs://avs-parquet-sample"

# The schema is encoded in the Parquet files, so it is not needed in the script;
# it is included here only to show the shape of the data.
schema = StructType([
    StructField("productId", StringType(), True),
    StructField("partnerId", StringType(), True),
    StructField("image_embs", ArrayType(FloatType(), containsNull=False), True),
    StructField("text_embs", ArrayType(FloatType(), containsNull=False), True)
])

img_set = "img_set_1"
img_idx = "img_idx_1"
txt_idx = "text_idx_1"
txt_set = "text_set_1"


def process_row(row, client):
    try:
        img_doc = {"productId": row.productId, "partnerId": row.partnerId, "image_embs": row.image_embs}
        txt_doc = {"productId": row.productId, "text_embs": row.text_embs}
        client.upsert(namespace="test", set_name=img_set, key=img_doc["productId"], record_data=img_doc)
        client.upsert(namespace="test", set_name=txt_set, key=txt_doc["productId"], record_data=txt_doc)
    except Exception:
        raise


def init_client(host, port):
    client = Client(seeds=types.HostPort(host=host, port=port), is_loadbalancer=True)
    return client


# Wait for the index to finish indexing records
def wait_for_indexing(index: Index):
    import time

    vertices = 0
    unmerged_recs = 0

    # Wait for the index to have vertices and no unmerged records
    while vertices == 0 or unmerged_recs > 0:
        status = index.status()

        vertices = status.index_healer_vertices_valid
        unmerged_recs = status.unmerged_record_count

        time.sleep(0.5)


def process_partition(partition, host, port):
    client = None
    try:
        client = init_client(host, port)
        for row in partition:
            process_row(row, client)
    finally:
        if client:
            client.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PySpark AVS Integration')
    parser.add_argument('--host', required=True, help='AVS hostname')
    parser.add_argument('--port', type=int, required=True, help='AVS port')
    parser.add_argument('--read_path', required=True, help='GCS path to read Parquet data')

    args = parser.parse_args()

    # Initialize Spark session
    spark = SparkSession.builder.appName("AVS python client integration with Spark").getOrCreate()

    # Read Parquet data from GCS using the read_path argument
    df_parquet = spark.read.parquet(args.read_path)

    # Initialize the AVS client with the command-line arguments
    avs_client = Client(seeds=types.HostPort(host=args.host, port=args.port), is_loadbalancer=True)
    avs_client.index_create(namespace="test", name=img_idx, sets=img_set, vector_field="image_embs", dimensions=256)
    avs_client.index_create(namespace="test", name=txt_idx, sets=txt_set, vector_field="text_embs", dimensions=256)

    # Target the image index
    index = avs_client.index(namespace="test", name=img_idx)

    # Process partitions, creating one AVS client per partition on the workers
    process_partition_function = lambda partition: process_partition(partition, args.host, args.port)
    df_parquet.rdd.foreachPartition(process_partition_function)

    # Fetch the first row and execute a vector search as an example
    first_row = df_parquet.first()
    image_embs = first_row['image_embs']

    try:
        # Wait for the index to finish indexing data
        wait_for_indexing(index)

        # Search against the image embedding
        results = index.vector_search(query=image_embs, include_fields=["productId", "partnerId"])
        print(f"Len of vector results: {len(results)}")
        for result in results:
            print(f"{result.key} -> {result.fields}")
    except Exception:
        raise
    finally:
        avs_client.close()
        spark.stop()
```
Read the Docs
For details about using the Python client, visit our Read the Docs page.