| name | description |
|---|---|
| pyspark-patterns | PySpark best practices, TableUtilities methods, ETL patterns, logging standards, and DataFrame operations for this project. Use when writing or debugging PySpark code. |
# PySpark Patterns & Best Practices

Comprehensive guide to PySpark patterns used in the Unify data migration project.

## Core Principle

Always use DataFrame operations over raw SQL when possible.
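For example, the same filter expressed through the DataFrame API rather than a SQL string stays composable and easier to refactor (assuming an active `spark` session; the table and column names are illustrative):

```python
from pyspark.sql.functions import col

# Preferred: DataFrame API
df_active = spark.table("silver_db.s_people").filter(col("status") == "active")

# Avoid where possible: raw SQL string
df_active_sql = spark.sql("SELECT * FROM silver_db.s_people WHERE status = 'active'")
```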
## TableUtilities Class Methods

Central utility class providing standardized DataFrame operations.

### add_row_hash()

Add hash column for change detection and deduplication.

```python
table_utilities = TableUtilities()
df_with_hash = table_utilities.add_row_hash(df)
```
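For intuition, a row hash of this kind is typically a digest over the row's columns; a minimal sketch assuming `sha2` over a delimited concatenation (the actual `add_row_hash()` implementation may choose columns and hashing differently):

```python
from pyspark.sql.functions import sha2, concat_ws, col

def add_row_hash_sketch(df, columns=None):
    # Hash a pipe-delimited concatenation of the chosen columns (all columns by default).
    # Note: concat_ws skips nulls, so null and empty string hash the same here.
    columns = columns or df.columns
    hashed = sha2(concat_ws("|", *[col(c).cast("string") for c in columns]), 256)
    return df.withColumn("row_hash", hashed)
```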
### save_as_table()

Standard table save with timestamp conversion and automatic filtering.

```python
table_utilities.save_as_table(df, "database.table_name")
```

Features:
- Converts timestamp columns automatically
- Filters to the last N years when a `date_created` column exists (controlled by `NUMBER_OF_YEARS`)
- Prevents full dataset processing in local development
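The year-based filter is roughly equivalent to the sketch below, assuming `NUMBER_OF_YEARS` is an integer setting and `date_created` is a timestamp column (the real method also performs the save itself):

```python
from pyspark.sql.functions import col, add_months, current_date

NUMBER_OF_YEARS = 3  # assumed example value; the project setting may differ

def filter_recent_years(df):
    # Keep only rows created within the last NUMBER_OF_YEARS years, if the column exists
    if "date_created" in df.columns:
        df = df.filter(col("date_created") >= add_months(current_date(), -12 * NUMBER_OF_YEARS))
    return df
```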
### clean_date_time_columns()

Intelligent timestamp parsing for various date formats.

```python
df_cleaned = table_utilities.clean_date_time_columns(df)
```
### Deduplication Methods

Simple deduplication (all columns):

```python
df_deduped = table_utilities.drop_duplicates_simple(df)
```

Advanced deduplication (specific columns, ordering):

```python
df_deduped = table_utilities.drop_duplicates_advanced(
    df,
    partition_columns=["id"],
    order_columns=["date_created"]
)
```
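This style of deduplication is conventionally a window function: number the rows within each partition key and keep the first row by the ordering columns. A minimal equivalent sketch (the actual `drop_duplicates_advanced()` may differ, e.g. in sort direction):

```python
from pyspark.sql import Window
from pyspark.sql.functions import col, row_number

window_spec = Window.partitionBy("id").orderBy(col("date_created").desc())
df_deduped = (
    df.withColumn("row_num", row_number().over(window_spec))
    .filter(col("row_num") == 1)  # keep one row per id (here: the newest)
    .drop("row_num")
)
```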
### filter_and_drop_column()

Remove duplicate flags after processing.

```python
df_filtered = table_utilities.filter_and_drop_column(df, "is_duplicate")
```
### generate_deduplicate()

Compare with existing table and identify new/changed records.

```python
df_new = table_utilities.generate_deduplicate(df, "database.existing_table")
```
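Conceptually this is a hash comparison against the existing table; a hedged sketch using a left anti join, assuming both sides carry a `row_hash` column from `add_row_hash()` (the real method may compare other columns or also flag updates):

```python
df_existing_hashes = spark.table("database.existing_table").select("row_hash")
df_new = df.join(df_existing_hashes, on="row_hash", how="left_anti")  # rows not already in the table
```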
### generate_unique_ids()

Generate auto-incrementing unique identifiers.

```python
df_with_id = table_utilities.generate_unique_ids(df, "unique_id_column_name")
```
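A common way to build auto-incrementing IDs in Spark is `row_number()` over a global window, as sketched below; note that the single-partition window limits scalability, and the real method may instead offset from the target table's current maximum ID:

```python
from pyspark.sql import Window
from pyspark.sql.functions import lit, row_number

def generate_unique_ids_sketch(df, id_column):
    # orderBy(lit(1)) imposes an arbitrary total order purely for numbering
    window_spec = Window.orderBy(lit(1))
    return df.withColumn(id_column, row_number().over(window_spec))
```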
## ETL Class Pattern

All silver and gold transformations follow this standardized pattern:

```python
class TableName:
    def __init__(self, bronze_table_name: str):
        self.bronze_table_name = bronze_table_name
        self.silver_database_name = f"silver_{self.bronze_table_name.split('.')[0].split('_')[-1]}"
        self.silver_table_name = self.bronze_table_name.split(".")[-1].replace("b_", "s_")
        # Execute ETL pipeline
        self.extract_sdf = self.extract()
        self.transform_sdf = self.transform()
        self.load()

    @synapse_error_print_handler
    def extract(self) -> DataFrame:
        """Extract data from source tables."""
        logger.info(f"Extracting from {self.bronze_table_name}")
        df = spark.table(self.bronze_table_name)
        logger.success(f"Extracted {df.count()} records")
        return df

    @synapse_error_print_handler
    def transform(self) -> DataFrame:
        """Transform data according to business rules."""
        logger.info("Starting transformation")
        # Apply transformations
        transformed_df = self.extract_sdf.filter(...).select(...)
        logger.success("Transformation complete")
        return transformed_df

    @synapse_error_print_handler
    def load(self) -> None:
        """Load data to target table."""
        logger.info(f"Loading to {self.silver_database_name}.{self.silver_table_name}")
        table_utilities.save_as_table(
            self.transform_sdf,
            f"{self.silver_database_name}.{self.silver_table_name}"
        )
        logger.success(f"Successfully loaded {self.silver_table_name}")


# Instantiate with exception handling
try:
    TableName("bronze_database.b_table_name")
except Exception as e:
    logger.error(f"Error processing TableName: {str(e)}")
    raise e
```
## Logging Standards

### Use NotebookLogger (Never print())

```python
from utilities.session_optimiser import NotebookLogger

logger = NotebookLogger()

# Log levels
logger.info("Starting process")             # Informational messages
logger.warning("Potential issue detected")  # Warnings
logger.error("Operation failed")            # Errors
logger.success("Process completed")         # Success messages
```
### Logging Best Practices

- **Always include table/database names:**
  ```python
  logger.info(f"Processing table {database}.{table}")
  ```
- **Log at key milestones:**
  ```python
  logger.info("Starting extraction")
  # ... extraction code
  logger.success("Extraction complete")
  ```
- **Include counts and metrics:**
  ```python
  logger.info(f"Extracted {df.count()} records from {table}")
  ```
- **Error context:**
  ```python
  logger.error(f"Failed to process {table}: {str(e)}")
  ```
## Error Handling Pattern

### @synapse_error_print_handler Decorator

Wrap ALL processing functions with this decorator:

```python
from utilities.session_optimiser import synapse_error_print_handler

@synapse_error_print_handler
def extract(self) -> DataFrame:
    # Your code here
    return df
```

Benefits:
- Consistent error handling across codebase
- Automatic error logging
- Graceful error propagation
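For intuition, such a decorator typically amounts to a try/except wrapper that logs and re-raises; a simplified, hypothetical sketch (the real implementation lives in `utilities.session_optimiser` and may do more):

```python
import functools

def error_print_handler_sketch(func):
    """Log any exception with the failing function's name, then re-raise."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            logger.error(f"{func.__name__} failed: {str(e)}")  # assumes the NotebookLogger shown above
            raise
    return wrapper
```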
### Exception Handling at Instantiation

```python
try:
    MyETLClass("source_table")
except Exception as e:
    logger.error(f"Error processing MyETLClass: {str(e)}")
    raise e
```
## DataFrame Operations Patterns

### Filtering

```python
# Use col() for clarity
from pyspark.sql.functions import col

df_filtered = df.filter(col("status") == "active")
df_filtered = df.filter((col("age") > 18) & (col("country") == "AU"))
```
### Selecting and Aliasing

```python
from pyspark.sql.functions import col, lit

df_selected = df.select(
    col("id"),
    col("name").alias("person_name"),
    lit("constant_value").alias("constant_column")
)
```
### Joins

```python
# Always use explicit join keys and type
df_joined = df1.join(
    df2,
    df1["id"] == df2["person_id"],
    "inner"  # inner, left, right, outer
)

# Drop duplicate columns after join
df_joined = df_joined.drop(df2["person_id"])
```
### Window Functions

```python
from pyspark.sql import Window
from pyspark.sql.functions import col, row_number, rank, dense_rank

window_spec = Window.partitionBy("category").orderBy(col("date").desc())

df_windowed = df.withColumn(
    "row_num",
    row_number().over(window_spec)
).filter(col("row_num") == 1)
```
### Aggregations

```python
from pyspark.sql.functions import sum, avg, count, max, min

df_agg = df.groupBy("category").agg(
    count("*").alias("total_count"),
    sum("amount").alias("total_amount"),
    avg("amount").alias("avg_amount")
)
```
## JDBC Connection Pattern

```python
import os

def get_connection_properties() -> dict:
    """Get JDBC connection properties."""
    return {
        "user": os.getenv("DB_USER"),
        "password": os.getenv("DB_PASSWORD"),
        "driver": "com.microsoft.sqlserver.jdbc.SQLServerDriver"
    }


# Use for JDBC reads
df = spark.read.jdbc(
    url=jdbc_url,
    table="schema.table",
    properties=get_connection_properties()
)
```
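To push filtering down to the source database instead of loading the full table, the `table` argument can also be a parenthesised subquery with an alias (standard Spark JDBC behaviour; the column names and cut-off date below are placeholders):

```python
df_recent = spark.read.jdbc(
    url=jdbc_url,
    table="(SELECT id, name, date_created FROM schema.table WHERE date_created >= '2022-01-01') AS src",
    properties=get_connection_properties()
)
```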
## Session Management

### Get Optimized Spark Session

```python
from utilities.session_optimiser import SparkOptimiser

spark = SparkOptimiser.get_optimised_spark_session()
```

### Reset Spark Context

```python
table_utilities.reset_spark_context()
```

When to use:
- Memory issues
- Multiple Spark sessions
- After large operations
## Memory Management

### Caching

```python
# Cache frequently accessed DataFrames
df_cached = df.cache()

# Unpersist when done
df_cached.unpersist()
```
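Note that `cache()` is lazy: nothing is stored until an action runs, so trigger one before relying on the cached data.

```python
df_cached = df.cache()
df_cached.count()  # action forces materialisation into memory
# ... reuse df_cached across several transformations ...
df_cached.unpersist()
```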
### Partitioning

```python
# Repartition for better parallelism
df_repartitioned = df.repartition(10)

# Coalesce to reduce partitions
df_coalesced = df.coalesce(1)
```
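Repartitioning can also target columns, which helps when downstream joins or aggregations group on those keys (the column name and partition count below are illustrative):

```python
# Co-locate rows that share the same key before a wide operation
df_by_key = df.repartition(200, "customer_id")
```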
## Common Pitfalls to Avoid
- Don't use print() statements - Use logger methods
- Don't read entire tables without filtering - Filter early
- Don't create DataFrames inside loops - Collect and batch
- Don't use collect() on large DataFrames - Process distributedly
- Don't forget to unpersist cached DataFrames - Memory leaks
## Performance Tips
- Filter early: Reduce data volume ASAP
- Use broadcast for small tables: Optimize joins (see the sketch below)
- Partition strategically: Balance parallelism
- Cache wisely: Only for reused DataFrames
- Use window functions: Instead of self-joins
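A hedged example of the broadcast hint mentioned above (whether a table is "small enough" depends on executor memory and the broadcast threshold):

```python
from pyspark.sql.functions import broadcast

# Ship the small lookup table to every executor instead of shuffling the large table
df_joined = df_large.join(broadcast(df_small_lookup), "id", "left")
```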
## Code Quality Standards

### Type Hints

```python
from pyspark.sql import DataFrame
from pyspark.sql.functions import col

def process_data(df: DataFrame, table_name: str) -> DataFrame:
    return df.filter(col("active") == True)
```
### Line Length

Maximum: 240 characters (not standard 88/120)

### Blank Lines

No blank lines inside functions - Keep functions compact

### Imports

All imports at top of file, never inside functions:

```python
from pyspark.sql import DataFrame
from pyspark.sql.functions import col, lit, when
from utilities.session_optimiser import TableUtilities, NotebookLogger
```