Apache Spark Samples

Examples for the Apache Spark big-data processing framework written in PySpark, covering DataFrames, Spark SQL, Structured Streaming, machine learning with MLlib, and performance tuning.

💻 PySpark DataFrame Operations python

🟢 simple

Essential DataFrame operations including creation, transformation, filtering, and basic analytics

# PySpark DataFrame Basic Operations

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, when, lit, udf
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

# Start (or reuse) a Spark session with a local warehouse directory.
spark = (
    SparkSession.builder
    .appName("BasicDataFrameOperations")
    .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse")
    .getOrCreate()
)

# Sample employee records as (name, age, department) tuples.
employee_rows = [
    ("Alice", 34, "Engineering"),
    ("Bob", 45, "Sales"),
    ("Charlie", 29, "Engineering"),
    ("Diana", 31, "Marketing"),
    ("Eve", 38, "Sales"),
]

# Explicit schema; every column is declared nullable.
employee_schema = StructType(
    [
        StructField("name", StringType(), nullable=True),
        StructField("age", IntegerType(), nullable=True),
        StructField("department", StringType(), nullable=True),
    ]
)

# Build the DataFrame and display it.
df = spark.createDataFrame(employee_rows, employee_schema)
df.show()

# Inspect the structure and summary statistics.
df.printSchema()
df.describe().show()

# Filtering: keep only the Engineering department (`where` is an alias of `filter`).
engineers = df.where(col("department") == "Engineering")
engineers.show()

# Transformation: derive a synthetic salary column from age.
df_with_salary = df.withColumn("salary", col("age") * 1000)
df_with_salary.show()

# Aggregation: average age and row count per department (dict-style agg).
dept_stats = df.groupBy("department").agg({"age": "avg", "name": "count"})
dept_stats.show()

# SQL: expose the DataFrame as a temporary view and query it.
df.createOrReplaceTempView("employees")
result = spark.sql("SELECT department, COUNT(*) as count FROM employees GROUP BY department")
result.show()

💻 Spark SQL Advanced Queries python

🟡 intermediate

Advanced SQL operations including window functions, joins, subqueries, and complex analytics

# Spark SQL Advanced Operations

from pyspark.sql import SparkSession, Window
# Import only the functions this sample uses. A wildcard import of
# pyspark.sql.functions shadows Python builtins such as sum, min, max and round.
from pyspark.sql.functions import col, rank, to_date

spark = SparkSession.builder \
    .appName("AdvancedSparkSQL") \
    .getOrCreate()

# Create sample tables
employees_data = [
    (1, "Alice", 34, "Engineering", 120000),
    (2, "Bob", 45, "Sales", 95000),
    (3, "Charlie", 29, "Engineering", 110000)
]

# sale_date is kept as an ISO-formatted string here and parsed further below.
sales_data = [
    (101, 1, "ProductA", 25000, "2023-01-05"),
    (102, 2, "ProductB", 18000, "2023-01-06"),
    (103, 1, "ProductC", 32000, "2023-01-08")
]

# Create DataFrames (column types are inferred from the Python values)
employees = spark.createDataFrame(employees_data, ["id", "name", "age", "department", "salary"])
sales = spark.createDataFrame(sales_data, ["sale_id", "employee_id", "product", "amount", "sale_date"])

# Create temporary views so the tables can be queried with SQL
employees.createOrReplaceTempView("employees")
sales.createOrReplaceTempView("sales")

# Window functions: rank employees by salary (highest first) within each department
window_spec = Window.partitionBy("department").orderBy(col("salary").desc())
employees.withColumn("rank", rank().over(window_spec)).show()

# Complex joins: attach each sale to the employee who made it
result = spark.sql("""
    SELECT e.name, e.department, s.product, s.amount
    FROM employees e
    JOIN sales s ON e.id = s.employee_id
    ORDER BY s.amount DESC
""")
result.show()

# Pivot table: one column per listed department holding its average salary.
# 'Marketing' has no rows in employees_data, so its pivot column is null.
pivot_result = spark.sql("""
    SELECT *
    FROM (
        SELECT department, salary
        FROM employees
    )
    PIVOT (
        AVG(salary)
        FOR department IN ('Engineering', 'Sales', 'Marketing')
    )
""")
pivot_result.show()

# Time series analysis: parse the string dates so they can be truncated to months
time_series = sales.withColumn("sale_date", to_date("sale_date"))
time_series.createOrReplaceTempView("time_series_sales")

monthly_sales = spark.sql("""
    SELECT
        DATE_TRUNC('month', sale_date) as month,
        SUM(amount) as total_sales,
        COUNT(*) as sale_count
    FROM time_series_sales
    GROUP BY DATE_TRUNC('month', sale_date)
    ORDER BY month
""")
monthly_sales.show()

💻 Spark Structured Streaming python

🔴 complex

Real-time data streaming with windowed aggregations, watermarking, and output modes

# Spark Structured Streaming

from pyspark.sql import SparkSession
# Namespace import instead of a wildcard: pyspark.sql.functions defines
# sum, count and similar names that would otherwise shadow Python builtins.
from pyspark.sql import functions as F
# StructField and DoubleType are needed by the schema below; previously they
# were missing, so defining `schema` raised a NameError.
from pyspark.sql.types import (
    DoubleType,
    IntegerType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
import json
import time

spark = (
    SparkSession.builder
    .appName("StructuredStreamingExample")
    .getOrCreate()
)

# Define schema for incoming JSON events
schema = StructType([
    StructField("event_time", TimestampType(), True),
    StructField("user_id", StringType(), True),
    StructField("event_type", StringType(), True),
    StructField("page_url", StringType(), True),
    StructField("session_id", StringType(), True),
    StructField("device_type", StringType(), True),
    StructField("location", StringType(), True),
    StructField("revenue", DoubleType(), True),
])

# Read from socket source (for demonstration only).
# NOTE: each streaming query started below opens its own connection to
# localhost:9999, so the server must accept more than one client.
lines = (
    spark.readStream
    .format("socket")
    .option("host", "localhost")
    .option("port", 9999)
    .load()
)

# Parse each text line as a JSON event and flatten it into columns
events = (
    lines
    .select(F.from_json(F.col("value").cast("string"), schema).alias("data"))
    .select("data.*")
)

# Windowed aggregations with watermark (defined for illustration; not started below).
windowed_counts = (
    events
    .withWatermark("event_time", "10 minutes")
    .groupBy(
        F.window(F.col("event_time"), "5 minutes"),
        F.col("event_type"),
    )
    .count()
    # Sorting a streaming aggregation is only allowed in "complete" output mode.
    .orderBy("window")
)

# Complex streaming analytics over 30-second windows sliding every 10 seconds.
# NOTE: this query runs in "complete" mode, in which all aggregate state is
# kept and the watermark does not evict old windows.
analytics_query = (
    events
    .withWatermark("event_time", "1 hour")
    .groupBy(
        F.window(F.col("event_time"), "30 seconds", "10 seconds"),
        F.col("device_type"),
        F.col("location"),
    )
    .agg(
        F.count("*").alias("event_count"),
        F.sum("revenue").alias("total_revenue"),
        F.avg("revenue").alias("avg_revenue"),
        # Exact distinct aggregations (countDistinct) are not supported on
        # streaming DataFrames; use the approximate variant instead.
        F.approx_count_distinct("user_id").alias("unique_users"),
        F.collect_set("event_type").alias("event_types"),
    )
    .filter(F.col("event_count") > 10)
    .orderBy(F.desc("total_revenue"))
)

# Start the console query
console_query = (
    analytics_query.writeStream
    .outputMode("complete")
    .format("console")
    .option("truncate", "false")
    .trigger(processingTime="30 seconds")
    .start()
)

# Save raw events to parquet
parquet_query = (
    events.writeStream
    .outputMode("append")
    .format("parquet")
    .option("path", "/tmp/spark_streaming_output")
    .option("checkpointLocation", "/tmp/spark_checkpoint")
    .trigger(processingTime="1 minute")
    .start()
)

# Block only after both queries are running. Calling awaitTermination() on
# the console query first would block there and never start the parquet query.
spark.streams.awaitAnyTermination()

💻 MLlib Machine Learning python

🔴 complex

Machine learning operations using MLlib including classification, regression, clustering, and feature engineering

# Spark MLlib Machine Learning

from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StringIndexer, StandardScaler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.regression import LinearRegression
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator,
    RegressionEvaluator,
)
from pyspark.ml import Pipeline

spark = SparkSession.builder \
    .appName("MLlibExample") \
    .getOrCreate()

# Classification Example
# Load sample data (assumes iris.csv has a header containing the columns in
# feature_cols plus "species" — TODO confirm against the actual file)
classification_data = spark.read.csv("iris.csv", header=True, inferSchema=True)

# Prepare features: combine the numeric columns into one vector column
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

# Convert string species labels to numeric indices
label_indexer = StringIndexer(inputCol="species", outputCol="label")

# Scale features
scaler = StandardScaler(inputCol="features", outputCol="scaled_features")

# Create model
lr = LogisticRegression(featuresCol="scaled_features", labelCol="label")

# Build pipeline
pipeline = Pipeline(stages=[assembler, scaler, label_indexer, lr])

# Split data (seeded for reproducibility)
train_data, test_data = classification_data.randomSplit([0.7, 0.3], seed=42)

# Train model
model = pipeline.fit(train_data)

# Make predictions
predictions = model.transform(test_data)

# Evaluate model. The species column of iris has three classes, so this is a
# multiclass problem. BinaryClassificationEvaluator's areaUnderROC is not
# meaningful for it; use multiclass accuracy instead.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label", predictionCol="prediction", metricName="accuracy"
)
accuracy = evaluator.evaluate(predictions)
print(f"Logistic regression accuracy: {accuracy:.4f}")

# Random Forest: same feature/label stages, different classifier
rf = RandomForestClassifier(featuresCol="scaled_features", labelCol="label", numTrees=10)
rf_pipeline = Pipeline(stages=[assembler, scaler, label_indexer, rf])
rf_model = rf_pipeline.fit(train_data)
rf_predictions = rf_model.transform(test_data)
rf_accuracy = evaluator.evaluate(rf_predictions)
print(f"Random forest accuracy: {rf_accuracy:.4f}")

# Regression Example
# Generate regression data: a synthetic target built from two features plus noise.
from pyspark.sql.functions import rand, when, col
regression_data = classification_data \
    .withColumn("target", (col("sepal_length") * col("petal_length")) + rand(seed=42) * 5)
# rand() is seeded so the noise, and therefore the reported RMSE, is reproducible.

# Linear regression
assembler_reg = VectorAssembler(inputCols=feature_cols, outputCol="features")
lr_reg = LinearRegression(featuresCol="features", labelCol="target")
pipeline_reg = Pipeline(stages=[assembler_reg, lr_reg])

train_reg, test_reg = regression_data.randomSplit([0.7, 0.3], seed=42)
reg_model = pipeline_reg.fit(train_reg)
reg_predictions = reg_model.transform(test_reg)

reg_evaluator = RegressionEvaluator(labelCol="target", metricName="rmse")
rmse = reg_evaluator.evaluate(reg_predictions)
print(f"RMSE: {rmse}")

# Clustering Example (unsupervised: uses the raw, unscaled feature vector)
kmeans = KMeans(featuresCol="features", k=3, seed=42)
kmeans_pipeline = Pipeline(stages=[assembler, kmeans])
kmeans_model = kmeans_pipeline.fit(classification_data)

# Get cluster centers from the fitted KMeans stage (last pipeline stage)
centers = kmeans_model.stages[-1].clusterCenters()
print("Cluster Centers:")
for i, center in enumerate(centers):
    print(f"Cluster {i}: {center}")

# Assign clusters and compare them with the true species
cluster_predictions = kmeans_model.transform(classification_data)
cluster_predictions.select("species", "prediction").show()

# Feature Engineering
from pyspark.ml.feature import PCA, PolynomialExpansion

# PCA for dimensionality reduction (4 input features -> 2 components)
pca = PCA(inputCol="features", outputCol="pca_features", k=2)
pca_pipeline = Pipeline(stages=[assembler, pca])
pca_model = pca_pipeline.fit(classification_data)
pca_result = pca_model.transform(classification_data)

# Polynomial features (degree-2 expansion of the feature vector)
poly_expansion = PolynomialExpansion(inputCol="features", outputCol="poly_features", degree=2)
poly_pipeline = Pipeline(stages=[assembler, poly_expansion])
poly_model = poly_pipeline.fit(classification_data)
poly_result = poly_model.transform(classification_data)

print("PCA Explained Variance Ratio:", pca_model.stages[-1].explainedVariance)
print("Original features shape:", len(feature_cols))
print("PCA features shape:", len(pca_model.stages[-1].explainedVariance))

💻 Performance Optimization python

🔴 complex

Spark performance tuning techniques including partitioning, caching, broadcast joins, and query optimization

# Spark Performance Optimization

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.storagelevel import StorageLevel
import time

# Performance-oriented session settings: adaptive query execution with
# partition coalescing, an explicit shuffle-partition count, and Kryo
# serialization. They are applied one by one onto the session builder.
perf_settings = {
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.adaptive.coalescePartitions.enabled": "true",
    "spark.sql.shuffle.partitions": "200",
    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
}

builder = SparkSession.builder.appName("PerformanceOptimization")
for conf_key, conf_value in perf_settings.items():
    builder = builder.config(conf_key, conf_value)
spark = builder.getOrCreate()

# Generate large dataset for testing
def generate_large_dataset(num_rows=1000000, seed=None):
    """Build ``num_rows`` synthetic transaction tuples for the benchmarks.

    Each tuple matches the DataFrame schema used by the caller:
    ``(user_id, department, product, quantity, price, event_date)``.

    Args:
        num_rows: Number of rows to generate (0 yields an empty list).
        seed: Optional seed for a private random generator, so benchmark
            runs can be reproduced. ``None`` (the default) gives a
            nondeterministic sequence.

    Returns:
        A list of ``(str, str, str, int, float, datetime.date)`` tuples.
        ``event_date`` falls within calendar year 2023.
    """
    import random
    from datetime import date, timedelta

    departments = ["Engineering", "Sales", "Marketing", "Finance", "HR"]
    products = ["ProductA", "ProductB", "ProductC", "ProductD", "ProductE"]

    # A private generator keeps the global `random` state untouched.
    rng = random.Random(seed)

    start_date = date(2023, 1, 1)
    # Largest day offset that still lands in 2023 (364). randint is inclusive
    # on both ends, so an upper bound of 365 would produce 2024-01-01.
    last_offset = (date(2024, 1, 1) - start_date).days - 1

    return [
        (
            f"user_{rng.randint(1, 100000)}",
            rng.choice(departments),
            rng.choice(products),
            rng.randint(10, 1000),
            rng.uniform(100, 10000),
            # A plain `date` (not `datetime`) to match the DateType column it feeds.
            start_date + timedelta(days=rng.randint(0, last_offset)),
        )
        for _ in range(num_rows)
    ]

# Create large DataFrame
large_data = generate_large_dataset(1000000)
schema = StructType([
    StructField("user_id", StringType(), True),
    StructField("department", StringType(), True),
    StructField("product", StringType(), True),
    StructField("quantity", IntegerType(), True),
    StructField("price", DoubleType(), True),
    StructField("event_date", DateType(), True)
])

df = spark.createDataFrame(large_data, schema)

# 1. Partitioning optimization
print("=== Partitioning Optimization ===")
start_time = time.time()

# Hash-partition by department so rows sharing a department are co-located
# for the department-keyed join and aggregations below. With only five
# distinct departments in the generated data, at most five partitions receive
# rows. This trades parallelism for locality; it does not increase parallelism.
df_partitioned = df.repartition("department")

# Cache the partitioned DataFrame
df_partitioned.cache()
df_partitioned.count()  # Materialize the cache

partition_time = time.time() - start_time
print(f"Partitioning and caching took: {partition_time:.2f} seconds")

# 2. Broadcast join optimization
print("\n=== Broadcast Join Optimization ===")

# Small lookup table
lookup_data = [
    ("Engineering", 1),
    ("Sales", 2),
    ("Marketing", 3),
    ("Finance", 4),
    ("HR", 5)
]

lookup_df = spark.createDataFrame(lookup_data, ["department", "dept_id"])

# Configure broadcast join threshold
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10MB")

# Perform broadcast join (explicit hint ships lookup_df to every executor)
start_time = time.time()
broadcast_join = df_partitioned.join(broadcast(lookup_df), "department")
broadcast_join.count()
broadcast_time = time.time() - start_time
print(f"Broadcast join took: {broadcast_time:.2f} seconds")

# 3. Query optimization
print("\n=== Query Optimization ===")

# Enable adaptive query execution (already set on the builder; repeated for clarity)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")

# Optimized aggregation query: filter before grouping to shrink the shuffle
start_time = time.time()
optimized_query = df_partitioned \
    .filter(col("price") > 500) \
    .groupBy("department", "product") \
    .agg(
        sum("quantity").alias("total_quantity"),
        sum("price").alias("total_revenue"),
        avg("price").alias("avg_price"),
        count("*").alias("transaction_count")
    ) \
    .filter(col("total_revenue") > 10000)

result = optimized_query.collect()
query_time = time.time() - start_time
print(f"Optimized query took: {query_time:.2f} seconds")

# 4. Caching strategies
print("\n=== Caching Strategies ===")

# A DataFrame holds a single storage level at a time. persist() on an
# already-persisted DataFrame is ignored with a warning, so each level is
# measured separately and unpersisted in between. (DataFrame.cache() uses
# the default MEMORY_AND_DISK-family level, not MEMORY_ONLY.)
df_cached_memory = df.persist(StorageLevel.MEMORY_ONLY)

# Warm up cache
df_cached_memory.count()

# Measure cache performance
start_time = time.time()
for _ in range(10):
    df_cached_memory.groupBy("department").count().collect()
cached_time = time.time() - start_time
print(f"10 cached aggregations took: {cached_time:.2f} seconds")
df_cached_memory.unpersist()

df_cached_memory_and_disk = df.persist(StorageLevel.MEMORY_AND_DISK)
df_cached_memory_and_disk.count()  # Warm up cache
start_time = time.time()
for _ in range(10):
    df_cached_memory_and_disk.groupBy("department").count().collect()
cached_disk_time = time.time() - start_time
print(f"10 cached aggregations (MEMORY_AND_DISK) took: {cached_disk_time:.2f} seconds")

# 5. DataFrame API vs SQL performance
print("\n=== DataFrame API vs SQL Performance ===")

# DataFrame API
start_time = time.time()
df_result_api = df_partitioned \
    .groupBy("department") \
    .agg(
        sum("quantity").alias("total_quantity"),
        sum("price").alias("total_revenue")
    ) \
    .orderBy(desc("total_revenue"))
api_result = df_result_api.collect()
api_time = time.time() - start_time

# SQL (same aggregation expressed as a query)
df_partitioned.createOrReplaceTempView("transactions")
start_time = time.time()
sql_result = spark.sql("""
    SELECT
        department,
        SUM(quantity) as total_quantity,
        SUM(price) as total_revenue
    FROM transactions
    GROUP BY department
    ORDER BY total_revenue DESC
""").collect()
sql_time = time.time() - start_time

print(f"DataFrame API time: {api_time:.2f} seconds")
print(f"SQL time: {sql_time:.2f} seconds")

# 6. Memory management
print("\n=== Memory Management ===")

# Executor/driver memory are launch-time settings that are frequently unset.
# spark.conf.get(key) without a default raises for an unset key, so read them
# from the SparkConf with an explicit fallback instead.
sc = spark.sparkContext
print(f"Executor memory: {sc.getConf().get('spark.executor.memory', 'not set')}")
print(f"Driver memory: {sc.getConf().get('spark.driver.memory', 'not set')}")
print(f"Shuffle partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")

# Clear caches to free memory (df_cached_memory was already unpersisted above)
df_cached_memory_and_disk.unpersist()
df_partitioned.unpersist()

# 7. Best practices checklist
print("\n=== Performance Best Practices ===")
best_practices = """
1. Use appropriate partitioning strategy
2. Cache frequently used DataFrames
3. Use broadcast joins for small tables
4. Enable adaptive query execution
5. Filter data early in the pipeline
6. Use columnar file formats (Parquet, ORC)
7. Optimize shuffle operations
8. Monitor Spark UI for bottlenecks
9. Use appropriate memory settings
10. Consider using Kryo serializer
"""

print(best_practices)

# Clean up
spark.stop()