Data Science Examples
Comprehensive data science examples covering pandas, NumPy, and matplotlib for data processing, analysis, and visualization
💻 NumPy Fundamentals
🟢 simple
⭐⭐
NumPy numerical computing fundamentals, covering array operations, mathematical operations, and linear algebra
⏱️ 30 min
🏷️ numpy, data-science, numerical-computing
Prerequisites:
Python basics, Basic mathematics
# NumPy Fundamentals for Data Science
import numpy as np
import matplotlib.pyplot as plt
# 1. Creating Arrays
def array_creation():
print("=== Array Creation ===")
# From Python lists
arr1 = np.array([1, 2, 3, 4, 5])
print("1D array:", arr1)
arr2 = np.array([[1, 2, 3], [4, 5, 6]])
print("2D array:\n", arr2)
# Using NumPy functions
zeros = np.zeros((3, 4))
print("Zeros array:\n", zeros)
ones = np.ones((2, 3))
print("Ones array:\n", ones)
# Special arrays
identity = np.eye(3)
print("Identity matrix:\n", identity)
# Range and linspace
range_arr = np.arange(0, 10, 2) # Start, stop, step
print("Arange:", range_arr)
linspace = np.linspace(0, 10, 5) # Start, stop, num_points
print("Linspace:", linspace)
# Random arrays
random_uniform = np.random.random((2, 3))
print("Random uniform:\n", random_uniform)
random_normal = np.random.randn(3, 3)
print("Random normal:\n", random_normal)
# 2. Array Properties and Manipulation
def array_properties():
print("\n=== Array Properties ===")
arr = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
print("Array shape:", arr.shape)
print("Array dimensions:", arr.ndim)
print("Array size:", arr.size)
print("Data type:", arr.dtype)
# Reshaping
reshaped = arr.reshape(2, 6)
print("Reshaped array:\n", reshaped)
# Flattening
flattened = arr.flatten()
print("Flattened array:", flattened)
# Transpose
transposed = arr.T
print("Transposed array:\n", transposed)
# 3. Array Indexing and Slicing
def array_indexing():
print("\n=== Array Indexing and Slicing ===")
arr = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]])
# Indexing
print("Element at [1,2]:", arr[1, 2])
print("First row:", arr[0])
print("Second column:", arr[:, 1])
# Slicing
print("First two rows:\n", arr[:2])
print("Last two columns:\n", arr[:, 2:])
print("Sub-array (rows 0-1, cols 1-2):\n", arr[:2, 1:3])
# Boolean indexing
mask = arr > 5
print("Elements > 5:", arr[mask])
# Fancy indexing
row_indices = np.array([0, 2])
col_indices = np.array([1, 3])
print("Fancy indexing:", arr[row_indices, col_indices])
# 4. Mathematical Operations
def mathematical_operations():
print("\n=== Mathematical Operations ===")
arr1 = np.array([1, 2, 3, 4])
arr2 = np.array([5, 6, 7, 8])
# Element-wise operations
print("Addition:", arr1 + arr2)
print("Multiplication:", arr1 * arr2)
print("Division:", arr2 / arr1)
print("Power:", arr1 ** 2)
# Universal functions (ufuncs)
print("Square root:", np.sqrt(arr1))
print("Exponential:", np.exp(arr1))
print("Logarithm:", np.log(arr2))
print("Sine:", np.sin(arr1))
# Array operations
matrix_a = np.array([[1, 2], [3, 4]])
matrix_b = np.array([[5, 6], [7, 8]])
print("Matrix multiplication:\n", np.dot(matrix_a, matrix_b))
print("Element-wise multiplication:\n", matrix_a * matrix_b)
# Statistical operations
data = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
print("Mean:", np.mean(data))
print("Median:", np.median(data))
print("Standard deviation:", np.std(data))
print("Variance:", np.var(data))
print("Min:", np.min(data))
print("Max:", np.max(data))
print("Sum:", np.sum(data))
print("Product:", np.prod(data))
# 5. Linear Algebra Operations
def linear_algebra():
print("\n=== Linear Algebra Operations ===")
# Matrix operations
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
print("Matrix A:\n", A)
print("Matrix B:\n", B)
# Determinant
print("Determinant of A:", np.linalg.det(A))
# Inverse
try:
inv_A = np.linalg.inv(A)
print("Inverse of A:\n", inv_A)
except np.linalg.LinAlgError:
print("Matrix A is singular")
# Eigenvalues and eigenvectors
eigenvalues, eigenvectors = np.linalg.eig(A)
print("Eigenvalues:", eigenvalues)
print("Eigenvectors:\n", eigenvectors)
# Solving linear equations
# Ax = b
A_eq = np.array([[2, 1], [1, 3]])
b_eq = np.array([5, 6])
x = np.linalg.solve(A_eq, b_eq)
print("Solution to Ax = b:", x)
# 6. Data Aggregation and Reduction
def data_aggregation():
print("\n=== Data Aggregation ===")
# Create sample data
data = np.random.randint(0, 100, (5, 4))
print("Sample data:\n", data)
# Row-wise operations
print("Row-wise sum:", data.sum(axis=1))
print("Row-wise mean:", data.mean(axis=1))
print("Row-wise max:", data.max(axis=1))
# Column-wise operations
print("Column-wise sum:", data.sum(axis=0))
print("Column-wise mean:", data.mean(axis=0))
print("Column-wise max:", data.max(axis=0))
# Cumulative operations
print("Cumulative sum:", np.cumsum(data))
print("Cumulative product:", np.cumprod(data))
# Sorting
print("Sorted array:", np.sort(data.flatten()))
print("Sorted by rows:", np.sort(data, axis=1))
print("Sorted by columns:", np.sort(data, axis=0))
# 7. Broadcasting and Shape Manipulation
def broadcasting_examples():
print("\n=== Broadcasting Examples ===")
# Basic broadcasting
a = np.array([[1, 2, 3], [4, 5, 6]])
b = np.array([10, 20, 30])
print("Array a:\n", a)
print("Array b:", b)
print("Broadcasted addition:\n", a + b)
# More complex broadcasting
x = np.array([[1], [2], [3]]) # Shape (3,1)
y = np.array([10, 20, 30, 40]) # Shape (4,)
result = x + y # Results in (3,4) array
print("Complex broadcasting result:\n", result)
# Outer product
outer = np.outer(x.flatten(), y)
print("Outer product:\n", outer)
# 8. Performance Comparison
def performance_comparison():
print("\n=== Performance Comparison ===")
import time
# Large arrays
size = 1000000
list1 = list(range(size))
list2 = list(range(size, 2*size))
arr1 = np.arange(size)
arr2 = np.arange(size, 2*size)
# Python list operation
start = time.time()
result_list = [a + b for a, b in zip(list1, list2)]
list_time = time.time() - start
# NumPy operation
start = time.time()
result_array = arr1 + arr2
numpy_time = time.time() - start
print(f"Python list time: {list_time:.4f} seconds")
print(f"NumPy array time: {numpy_time:.4f} seconds")
print(f"Speedup: {list_time/numpy_time:.2f}x")
# 9. Real-world Data Analysis Example
def sales_data_analysis():
print("\n=== Sales Data Analysis ===")
# Simulate sales data (products x months)
np.random.seed(42)
products = ['Product_A', 'Product_B', 'Product_C', 'Product_D', 'Product_E']
months = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun']
# Random sales data
sales_data = np.random.randint(100, 1000, (5, 6))
print("Sales data (products x months):\n", sales_data)
# Basic statistics
print("\nTotal sales per product:", sales_data.sum(axis=1))
print("Total sales per month:", sales_data.sum(axis=0))
print("Best selling product:", products[np.argmax(sales_data.sum(axis=1))])
print("Best month:", months[np.argmax(sales_data.sum(axis=0))])
# Monthly growth rate
monthly_totals = sales_data.sum(axis=0)
growth_rate = (monthly_totals[1:] - monthly_totals[:-1]) / monthly_totals[:-1] * 100
print("Monthly growth rates (%):", growth_rate)
# Product performance categories
product_totals = sales_data.sum(axis=1)
avg_sales = np.mean(product_totals)
    # A plain Python list cannot be indexed with a boolean array, so convert it
    products_arr = np.array(products)
    high_performers = products_arr[product_totals > avg_sales * 1.2]
    low_performers = products_arr[product_totals < avg_sales * 0.8]
print("High performers (>120% average):", high_performers)
print("Low performers (<80% average):", low_performers)
# Run all examples
if __name__ == "__main__":
print("NumPy Data Science Examples")
print("=" * 40)
array_creation()
array_properties()
array_indexing()
mathematical_operations()
linear_algebra()
data_aggregation()
broadcasting_examples()
performance_comparison()
sales_data_analysis()
💻 Pandas Data Analysis
🟡 intermediate
⭐⭐⭐
A complete pandas workflow for data manipulation, cleaning, analysis, and time series operations
⏱️ 45 min
🏷️ pandas, data-analysis, data-science
Prerequisites:
Python basics, NumPy fundamentals, Basic statistics
# Pandas Data Analysis for Data Science
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
# Set styling
plt.style.use('default')
sns.set_palette("husl")
# 1. Creating and Loading Data
def data_creation_loading():
print("=== Data Creation and Loading ===")
# Create DataFrame from dictionary
data = {
'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'],
'age': [25, 30, 35, 28, 32],
'city': ['New York', 'London', 'Paris', 'Tokyo', 'Sydney'],
'salary': [70000, 80000, 90000, 75000, 85000],
'department': ['IT', 'Finance', 'IT', 'Marketing', 'Finance']
}
df = pd.DataFrame(data)
print("Created DataFrame:\n", df)
# Create DataFrame from lists
names = ['Product_A', 'Product_B', 'Product_C']
prices = [100, 150, 200]
quantities = [50, 30, 40]
df_products = pd.DataFrame({
'Product': names,
'Price': prices,
'Quantity': quantities
})
print("\nProducts DataFrame:\n", df_products)
# Load CSV (simulated)
# df_csv = pd.read_csv('data.csv')
# print("Loaded from CSV:\n", df_csv.head())
# Basic information
print("\nDataFrame Info:")
df.info()
print("\nStatistical Summary:")
print(df.describe())
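# Added sketch of the commented-out CSV load above, made self-contained with
# an in-memory buffer so it runs without a data.csv on disk ('data.csv' in
# that comment is a placeholder file name; io is standard library).
def csv_round_trip():
    import io
    df = pd.DataFrame({'a': [1, 2], 'b': ['x', 'y']})
    buffer = io.StringIO()
    df.to_csv(buffer, index=False)   # write the frame out as CSV text
    buffer.seek(0)
    print("Round-tripped CSV:\n", pd.read_csv(buffer))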
# 2. Data Selection and Filtering
def data_selection_filtering():
print("\n=== Data Selection and Filtering ===")
# Sample data
data = {
'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve', 'Frank', 'Grace'],
'age': [25, 30, 35, 28, 32, 45, 29],
'city': ['New York', 'London', 'Paris', 'Tokyo', 'Sydney', 'London', 'Paris'],
'salary': [70000, 80000, 90000, 75000, 85000, 95000, 72000],
'department': ['IT', 'Finance', 'IT', 'Marketing', 'Finance', 'IT', 'Marketing'],
'join_date': pd.to_datetime(['2020-01-15', '2019-03-20', '2018-07-10',
'2021-02-01', '2020-06-15', '2017-11-30', '2021-08-20'])
}
df = pd.DataFrame(data)
# Column selection
print("Name and Salary columns:\n", df[['name', 'salary']])
# Row selection by position
print("\nFirst 3 rows:\n", df.head(3))
print("\nRows 2-4:\n", df.iloc[2:5])
# Row selection by label
print("\nRows with index 1-3:\n", df.loc[1:3])
# Conditional filtering
high_salary = df[df['salary'] > 80000]
print("\nEmployees with salary > 80k:\n", high_salary)
# Multiple conditions
it_dept_30plus = df[(df['department'] == 'IT') & (df['age'] >= 30)]
print("\nIT employees age 30+:\n", it_dept_30plus)
# Using isin()
target_cities = ['New York', 'London', 'Paris']
eu_employees = df[df['city'].isin(target_cities)]
print("\nEmployees in target cities:\n", eu_employees)
# String methods
print("\nNames starting with 'A':\n", df[df['name'].str.startswith('A')])
# 3. Data Cleaning and Preparation
def data_cleaning():
print("\n=== Data Cleaning and Preparation ===")
# Create messy data
messy_data = {
'name': ['Alice', 'Bob', None, 'Diana', 'Eve', ' Frank ', 'Grace'],
'age': [25, 30, 35, None, 32, 45, 29],
'salary': [70000, 80000, 90000, 75000, None, 95000, 72000],
'department': ['IT', 'Finance', 'IT', 'Marketing', 'FINANCE', 'IT', 'marketing'],
'email': ['[email protected]', 'bob@email', '[email protected]',
None, '[email protected]', '[email protected]', '[email protected]']
}
df = pd.DataFrame(messy_data)
print("Original messy data:\n", df)
# Handle missing values
print("\nMissing values count:\n", df.isnull().sum())
# Drop rows with missing names
    df_clean = df.dropna(subset=['name']).copy()  # .copy() avoids SettingWithCopyWarning in the assignments below
# Fill missing ages with mean
df_clean['age'] = df_clean['age'].fillna(df_clean['age'].mean())
# Fill missing salaries with median
df_clean['salary'] = df_clean['salary'].fillna(df_clean['salary'].median())
print("\nAfter handling missing values:\n", df_clean)
# Clean string data
df_clean['name'] = df_clean['name'].str.strip() # Remove whitespace
df_clean['department'] = df_clean['department'].str.title() # Title case
df_clean['email'] = df_clean['email'].str.lower() # Lowercase
# Validate email format (simple check)
email_pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
df_clean['valid_email'] = df_clean['email'].str.match(email_pattern)
print("\nAfter cleaning strings:\n", df_clean)
# Remove duplicate rows
df_clean = df_clean.drop_duplicates()
print("\nAfter removing duplicates (if any):\n", df_clean)
# 4. Data Transformation and Feature Engineering
def data_transformation():
print("\n=== Data Transformation ===")
# Sample sales data
sales_data = {
'product': ['A', 'B', 'A', 'C', 'B', 'A', 'C', 'B', 'A'],
'region': ['North', 'South', 'North', 'East', 'West', 'South', 'North', 'East', 'West'],
'sales': [1000, 1500, 1200, 800, 2000, 900, 1100, 1800, 700],
'date': pd.to_datetime(['2023-01-01', '2023-01-02', '2023-01-03',
'2023-01-04', '2023-01-05', '2023-01-06',
'2023-01-07', '2023-01-08', '2023-01-09'])
}
df = pd.DataFrame(sales_data)
# Create new features
df['sales_tax'] = df['sales'] * 0.08 # 8% tax
df['total_revenue'] = df['sales'] + df['sales_tax']
df['quarter'] = df['date'].dt.quarter
df['month'] = df['date'].dt.month
df['weekday'] = df['date'].dt.day_name()
# Categorical encoding
df['product_encoded'] = df['product'].astype('category').cat.codes
df['region_encoded'] = df['region'].astype('category').cat.codes
# One-hot encoding
df_encoded = pd.get_dummies(df, columns=['product', 'region'], prefix=['prod', 'reg'])
print("Original DataFrame:\n", df.head())
print("\nOne-hot encoded DataFrame:\n", df_encoded.head())
# Grouping and aggregation
product_stats = df.groupby('product')['sales'].agg(['mean', 'sum', 'count', 'std'])
print("\nProduct statistics:\n", product_stats)
# Pivot table
pivot_table = df.pivot_table(values='sales', index='product', columns='region', aggfunc='sum', fill_value=0)
print("\nPivot table:\n", pivot_table)
# 5. Time Series Analysis
def time_series_analysis():
print("\n=== Time Series Analysis ===")
# Create time series data
date_range = pd.date_range(start='2023-01-01', end='2023-12-31', freq='D')
np.random.seed(42)
# Simulate sales data with trend and seasonality
trend = np.linspace(100, 200, len(date_range))
seasonal = 50 * np.sin(2 * np.pi * np.arange(len(date_range)) / 30) # Monthly seasonality
noise = np.random.normal(0, 20, len(date_range))
sales = trend + seasonal + noise
sales[sales < 0] = 0 # Ensure non-negative
df_ts = pd.DataFrame({
'date': date_range,
'sales': sales
})
df_ts.set_index('date', inplace=True)
print("Time series data (first 10 days):\n", df_ts.head(10))
# Resampling
monthly_sales = df_ts.resample('M').sum()
print("\nMonthly sales:\n", monthly_sales.head())
weekly_sales = df_ts.resample('W').mean()
print("\nWeekly average sales:\n", weekly_sales.head())
# Moving averages
df_ts['sales_7d_ma'] = df_ts['sales'].rolling(window=7).mean()
df_ts['sales_30d_ma'] = df_ts['sales'].rolling(window=30).mean()
print("\nWith moving averages (last 5 days):\n", df_ts.tail(5))
# Date-based filtering
    q1_sales = df_ts.loc['2023-01':'2023-03']  # .loc for partial-date slicing; bare [] row slicing is deprecated
print("\nQ1 2023 total sales:", q1_sales['sales'].sum())
# Lag features
df_ts['sales_lag_1'] = df_ts['sales'].shift(1)
df_ts['sales_lag_7'] = df_ts['sales'].shift(7)
print("\nWith lag features (last 3 days):\n",
df_ts[['sales', 'sales_lag_1', 'sales_lag_7']].tail(3))
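# Added note: Series.pct_change computes period-over-period growth in one
# call, the pandas counterpart of the manual growth-rate arithmetic in the
# NumPy sales example.
def pct_change_demo():
    sales = pd.Series([100.0, 110.0, 99.0, 120.0])
    print("Fractional change:\n", sales.pct_change())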
# 6. Data Visualization with Pandas
def pandas_visualization():
print("\n=== Data Visualization ===")
# Create comprehensive dataset
np.random.seed(42)
products = ['A', 'B', 'C', 'D', 'E']
regions = ['North', 'South', 'East', 'West']
months = pd.date_range('2023-01-01', '2023-06-30', freq='M')
data = []
for month in months:
for product in products:
for region in regions:
base_sales = np.random.randint(500, 2000)
seasonal_factor = 1 + 0.2 * np.sin(2 * np.pi * month.month / 12)
sales = int(base_sales * seasonal_factor * np.random.normal(1, 0.1))
data.append({
'date': month,
'product': product,
'region': region,
'sales': max(0, sales)
})
df_viz = pd.DataFrame(data)
# Basic statistics
print("Dataset shape:", df_viz.shape)
print("Columns:", df_viz.columns.tolist())
print("\nSales statistics:\n", df_viz['sales'].describe())
# Product performance
product_sales = df_viz.groupby('product')['sales'].sum().sort_values(ascending=False)
print("\nTotal sales by product:\n", product_sales)
# Region performance
region_sales = df_viz.groupby('region')['sales'].sum().sort_values(ascending=False)
print("\nTotal sales by region:\n", region_sales)
# Monthly trends
monthly_sales = df_viz.groupby(df_viz['date'].dt.strftime('%Y-%m'))['sales'].sum()
print("\nMonthly sales trend:\n", monthly_sales)
# Correlation analysis (if we had numeric variables)
numeric_df = df_viz.select_dtypes(include=[np.number])
if len(numeric_df.columns) > 1:
print("\nCorrelation matrix:\n", numeric_df.corr())
# 7. Advanced Data Operations
def advanced_operations():
print("\n=== Advanced Data Operations ===")
# Merge operations
employees = pd.DataFrame({
'emp_id': [1, 2, 3, 4],
'name': ['Alice', 'Bob', 'Charlie', 'Diana'],
'dept_id': [101, 102, 101, 103]
})
departments = pd.DataFrame({
'dept_id': [101, 102, 103, 104],
'dept_name': ['IT', 'Finance', 'Marketing', 'HR'],
'budget': [1000000, 800000, 600000, 400000]
})
# Inner join
merged_inner = pd.merge(employees, departments, on='dept_id', how='inner')
print("Inner join:\n", merged_inner)
# Left join
merged_left = pd.merge(employees, departments, on='dept_id', how='left')
print("\nLeft join:\n", merged_left)
# Concatenation
q1_data = pd.DataFrame({
'month': ['Jan', 'Feb', 'Mar'],
'sales': [1000, 1200, 1100]
})
q2_data = pd.DataFrame({
'month': ['Apr', 'May', 'Jun'],
'sales': [1300, 1400, 1350]
})
half_year = pd.concat([q1_data, q2_data], ignore_index=True)
print("\nConcatenated data:\n", half_year)
# Apply custom functions
performance_data = pd.DataFrame({
'employee': ['Alice', 'Bob', 'Charlie', 'Diana'],
'score': [85, 92, 78, 95],
'projects': [5, 7, 3, 8]
})
def calculate_performance(row):
score_weight = 0.7
projects_weight = 0.3
return (row['score'] * score_weight + row['projects'] * projects_weight * 10)
performance_data['performance_score'] = performance_data.apply(calculate_performance, axis=1)
print("\nPerformance data with calculated score:\n", performance_data)
# Window functions
sales_data = pd.DataFrame({
'date': pd.date_range('2023-01-01', periods=10),
'daily_sales': [100, 120, 110, 130, 125, 140, 135, 150, 145, 160]
})
sales_data['cumulative_sum'] = sales_data['daily_sales'].cumsum()
sales_data['rolling_mean_3'] = sales_data['daily_sales'].rolling(window=3).mean()
sales_data['rolling_std_3'] = sales_data['daily_sales'].rolling(window=3).std()
print("\nSales data with window functions:\n", sales_data)
# Run all examples
if __name__ == "__main__":
print("Pandas Data Science Examples")
print("=" * 40)
data_creation_loading()
data_selection_filtering()
data_cleaning()
data_transformation()
time_series_analysis()
pandas_visualization()
advanced_operations()
print("\n" + "=" * 40)
print("Examples completed successfully!")
💻 Matplotlib Visualization
🟡 intermediate
⭐⭐⭐
A complete guide to creating publication-quality charts with matplotlib and seaborn
⏱️ 40 min
🏷️ matplotlib, seaborn, visualization, data-science
Prerequisites:
Python basics, NumPy, pandas, Basic statistics
# Matplotlib and Seaborn Data Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# Set style and figure parameters
plt.style.use('default') # or 'seaborn-v0_8', 'ggplot', etc.
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12
# 1. Basic Plot Types
def basic_plots():
print("=== Basic Plot Types ===")
# Sample data
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
y3 = np.exp(-x/5) * np.sin(x)
# Line plots
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
# Simple line plot
axes[0, 0].plot(x, y1, 'b-', linewidth=2, label='sin(x)')
axes[0, 0].plot(x, y2, 'r--', linewidth=2, label='cos(x)')
axes[0, 0].set_title('Trigonometric Functions')
axes[0, 0].set_xlabel('x')
axes[0, 0].set_ylabel('y')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)
# Scatter plot
np.random.seed(42)
x_scatter = np.random.randn(100)
y_scatter = x_scatter * 0.5 + np.random.randn(100) * 0.3
colors = np.random.rand(100)
scatter = axes[0, 1].scatter(x_scatter, y_scatter, c=colors, alpha=0.7,
cmap='viridis', s=50)
axes[0, 1].set_title('Scatter Plot with Color Mapping')
axes[0, 1].set_xlabel('X values')
axes[0, 1].set_ylabel('Y values')
plt.colorbar(scatter, ax=axes[0, 1])
# Bar plot
categories = ['A', 'B', 'C', 'D', 'E']
values = [23, 45, 56, 78, 32]
bars = axes[1, 0].bar(categories, values, color=['red', 'green', 'blue', 'orange', 'purple'])
axes[1, 0].set_title('Bar Chart')
axes[1, 0].set_xlabel('Categories')
axes[1, 0].set_ylabel('Values')
# Add value labels on bars
for bar, value in zip(bars, values):
axes[1, 0].text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1,
str(value), ha='center', va='bottom')
# Histogram
data_hist = np.random.normal(100, 15, 1000)
axes[1, 1].hist(data_hist, bins=30, alpha=0.7, color='skyblue', edgecolor='black')
axes[1, 1].set_title('Histogram')
axes[1, 1].set_xlabel('Value')
axes[1, 1].set_ylabel('Frequency')
axes[1, 1].axvline(data_hist.mean(), color='red', linestyle='--',
label=f'Mean: {data_hist.mean():.2f}')
axes[1, 1].legend()
plt.tight_layout()
plt.show()
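# Added note: plt.show() only displays the figure; to export it, call
# Figure.savefig before show (the file name below is a placeholder).
# fig.savefig('basic_plots.png', dpi=300, bbox_inches='tight')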
# 2. Advanced Plotting Techniques
def advanced_plots():
print("\n=== Advanced Plotting Techniques ===")
# Create sample dataset
np.random.seed(42)
n_points = 200
# Multiple subplots with different types
fig = plt.figure(figsize=(16, 12))
# Create grid layout
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
# 2D density plot
ax1 = fig.add_subplot(gs[0, 0])
x_2d = np.random.multivariate_normal([0, 0], [[1, 0.5], [0.5, 1]], 1000)
hb = ax1.hexbin(x_2d[:, 0], x_2d[:, 1], gridsize=20, cmap='Blues')
ax1.set_title('2D Hexbin Density Plot')
ax1.set_xlabel('X')
ax1.set_ylabel('Y')
plt.colorbar(hb, ax=ax1)
# Box plot with multiple groups
ax2 = fig.add_subplot(gs[0, 1])
data_box = [np.random.normal(0, std, 100) for std in range(1, 4)]
bp = ax2.boxplot(data_box, patch_artist=True, labels=['Group 1', 'Group 2', 'Group 3'])
colors = ['lightblue', 'lightgreen', 'lightpink']
for patch, color in zip(bp['boxes'], colors):
patch.set_facecolor(color)
ax2.set_title('Box Plot Comparison')
ax2.set_ylabel('Values')
# Violin plot
ax3 = fig.add_subplot(gs[0, 2])
data_violin = [np.random.normal(0, std, 100) for std in range(1, 4)]
vp = ax3.violinplot(data_violin, positions=[1, 2, 3], showmeans=True)
ax3.set_title('Violin Plot')
ax3.set_xticks([1, 2, 3])
ax3.set_xticklabels(['Group 1', 'Group 2', 'Group 3'])
ax3.set_ylabel('Values')
# Error bar plot
ax4 = fig.add_subplot(gs[1, 0])
x_err = np.arange(5)
y_err = np.array([20, 35, 30, 45, 40])
y_err_values = np.array([2, 3, 4, 2, 3])
ax4.errorbar(x_err, y_err, yerr=y_err_values, fmt='o-',
capsize=5, capthick=2, linewidth=2)
ax4.set_title('Error Bar Plot')
ax4.set_xlabel('Category')
ax4.set_ylabel('Mean ± Error')
# Stacked area plot
ax5 = fig.add_subplot(gs[1, 1])
x_area = np.arange(10)
y1_area = np.random.randint(1, 5, 10)
y2_area = np.random.randint(1, 5, 10)
y3_area = np.random.randint(1, 5, 10)
ax5.stackplot(x_area, y1_area, y2_area, y3_area,
labels=['Series 1', 'Series 2', 'Series 3'],
colors=['skyblue', 'lightgreen', 'lightcoral'])
ax5.set_title('Stacked Area Plot')
ax5.set_xlabel('X')
ax5.set_ylabel('Cumulative Values')
ax5.legend()
# Polar plot
ax6 = fig.add_subplot(gs[1, 2], projection='polar')
theta = np.linspace(0, 2*np.pi, 100)
r = np.abs(np.sin(theta) * np.cos(2*theta))
ax6.plot(theta, r, 'b-', linewidth=2)
ax6.fill(theta, r, alpha=0.3)
ax6.set_title('Polar Plot')
# Heatmap
ax7 = fig.add_subplot(gs[2, :2])
data_heatmap = np.random.randn(10, 12)
im = ax7.imshow(data_heatmap, cmap='coolwarm', aspect='auto')
ax7.set_title('Heatmap')
ax7.set_xlabel('Columns')
ax7.set_ylabel('Rows')
plt.colorbar(im, ax=ax7)
# Pie chart
ax8 = fig.add_subplot(gs[2, 2])
sizes = [30, 25, 20, 15, 10]
labels = ['A', 'B', 'C', 'D', 'E']
colors_pie = ['gold', 'lightcoral', 'lightskyblue', 'lightgreen', 'plum']
explode = (0.1, 0, 0, 0, 0) # explode first slice
wedges, texts, autotexts = ax8.pie(sizes, explode=explode, labels=labels, colors=colors_pie,
autopct='%1.1f%%', shadow=True, startangle=90)
ax8.set_title('Pie Chart')
# Enhance text appearance
for autotext in autotexts:
autotext.set_color('white')
autotext.set_weight('bold')
plt.suptitle('Advanced Visualization Techniques', fontsize=16, y=1.02)
plt.show()
# 3. Seaborn Statistical Plots
def seaborn_plots():
print("\n=== Seaborn Statistical Plots ===")
# Create sample datasets
np.random.seed(42)
# Dataset for regression plots
x_reg = np.linspace(0, 10, 100)
y_reg = 2 * x_reg + 1 + np.random.normal(0, 2, 100)
# Dataset for categorical plots
categories = ['A', 'B', 'C', 'D']
data_cat = []
for cat in categories:
data_cat.extend([(cat, np.random.normal(100 + categories.index(cat) * 10, 15))
for _ in range(50)])
df_cat = pd.DataFrame(data_cat, columns=['Category', 'Value'])
# Dataset for distribution plots
data_dist = np.random.gamma(2, 2, 1000)
# Create comprehensive seaborn plot
fig, axes = plt.subplots(2, 3, figsize=(18, 12))
# Regression plot with confidence interval
sns.regplot(x=x_reg, y=y_reg, ax=axes[0, 0],
scatter_kws={'alpha': 0.6}, line_kws={'color': 'red'})
axes[0, 0].set_title('Regression Plot with Confidence Interval')
# Box plot with swarm overlay
sns.boxplot(x='Category', y='Value', data=df_cat, ax=axes[0, 1])
sns.swarmplot(x='Category', y='Value', data=df_cat, ax=axes[0, 1],
color='black', alpha=0.5, size=4)
axes[0, 1].set_title('Box Plot with Swarm Overlay')
# Violin plot
sns.violinplot(x='Category', y='Value', data=df_cat, ax=axes[0, 2])
axes[0, 2].set_title('Violin Plot by Category')
# Distribution plot (histogram + KDE)
sns.histplot(data_dist, kde=True, ax=axes[1, 0])
axes[1, 0].set_title('Distribution Plot with KDE')
# Pair plot (requires 2D data)
# Creating correlation data
np.random.seed(42)
n_samples = 200
corr_data = {
'Variable1': np.random.normal(0, 1, n_samples),
'Variable2': np.random.normal(0, 1, n_samples),
'Variable3': np.random.normal(0, 1, n_samples)
}
# Add some correlation
corr_data['Variable2'] = 0.7 * corr_data['Variable1'] + 0.3 * corr_data['Variable2']
df_corr = pd.DataFrame(corr_data)
# Correlation heatmap
correlation_matrix = df_corr.corr()
sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0,
square=True, ax=axes[1, 1])
axes[1, 1].set_title('Correlation Heatmap')
# Count plot
count_data = np.random.choice(categories, 100)
sns.countplot(x=count_data, ax=axes[1, 2])
axes[1, 2].set_title('Count Plot')
axes[1, 2].set_xlabel('Category')
plt.tight_layout()
plt.show()
# 4. Real-world Data Visualization
def real_world_visualization():
print("\n=== Real-world Data Visualization ===")
# Create realistic sales dataset
np.random.seed(42)
# Time series sales data
dates = pd.date_range('2022-01-01', '2023-12-31', freq='D')
base_sales = 1000
trend = np.linspace(0, 500, len(dates))
seasonal = 200 * np.sin(2 * np.pi * np.arange(len(dates)) / 365.25)
noise = np.random.normal(0, 50, len(dates))
sales = base_sales + trend + seasonal + noise
sales[sales < 0] = 0 # Ensure non-negative
df_sales = pd.DataFrame({
'date': dates,
'sales': sales
})
# Create monthly aggregations
df_monthly = df_sales.set_index('date').resample('M').agg({
'sales': ['sum', 'mean', 'std']
}).reset_index()
df_monthly.columns = ['date', 'total_sales', 'avg_sales', 'sales_std']
# Create comprehensive dashboard
fig = plt.figure(figsize=(20, 15))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
# 1. Time series plot with trend
ax1 = fig.add_subplot(gs[0, :])
ax1.plot(df_sales['date'], df_sales['sales'], alpha=0.7, linewidth=1, label='Daily Sales')
# Add moving average
df_sales['MA30'] = df_sales['sales'].rolling(window=30).mean()
ax1.plot(df_sales['date'], df_sales['MA30'], 'r-', linewidth=2, label='30-day MA')
ax1.set_title('Daily Sales with Moving Average', fontsize=14, fontweight='bold')
ax1.set_xlabel('Date')
ax1.set_ylabel('Sales')
ax1.legend()
ax1.grid(True, alpha=0.3)
# 2. Monthly sales bar chart
ax2 = fig.add_subplot(gs[1, 0])
monthly_totals = df_sales.set_index('date').resample('M').sum()
ax2.bar(range(len(monthly_totals)), monthly_totals['sales'],
color=plt.cm.viridis(np.linspace(0, 1, len(monthly_totals))))
ax2.set_title('Monthly Total Sales')
ax2.set_xlabel('Month')
ax2.set_ylabel('Total Sales')
ax2.set_xticks(range(0, len(monthly_totals), 3))
ax2.set_xticklabels([month.strftime('%Y-%m') for month in monthly_totals.index[::3]],
rotation=45)
# 3. Sales distribution
ax3 = fig.add_subplot(gs[1, 1])
ax3.hist(df_sales['sales'], bins=50, alpha=0.7, color='skyblue', edgecolor='black')
ax3.axvline(df_sales['sales'].mean(), color='red', linestyle='--',
label=f'Mean: {df_sales["sales"].mean():.0f}')
ax3.axvline(df_sales['sales'].median(), color='orange', linestyle='--',
label=f'Median: {df_sales["sales"].median():.0f}')
ax3.set_title('Sales Distribution')
ax3.set_xlabel('Sales')
ax3.set_ylabel('Frequency')
ax3.legend()
# 4. Monthly statistics
ax4 = fig.add_subplot(gs[1, 2])
months = [month.strftime('%Y-%m') for month in monthly_totals.index]
ax4.plot(monthly_totals.index, monthly_totals['sales'], 'o-', label='Monthly Sales')
ax4.fill_between(monthly_totals.index, monthly_totals['sales'], alpha=0.3)
ax4.set_title('Monthly Sales Trend')
ax4.set_xlabel('Month')
ax4.set_ylabel('Sales')
    # The x-axis holds datetime values here, so ticks must be dates as well
    ax4.set_xticks(monthly_totals.index[::3])
    ax4.set_xticklabels(months[::3], rotation=45)
ax4.legend()
# 5. Seasonal pattern
ax5 = fig.add_subplot(gs[2, 0])
df_sales['month'] = df_sales['date'].dt.month
monthly_avg = df_sales.groupby('month')['sales'].mean()
ax5.bar(monthly_avg.index, monthly_avg.values, color='lightgreen')
ax5.set_title('Average Sales by Month')
ax5.set_xlabel('Month')
ax5.set_ylabel('Average Sales')
ax5.set_xticks(range(1, 13))
ax5.set_xticklabels(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
# 6. Day of week analysis
ax6 = fig.add_subplot(gs[2, 1])
df_sales['day_of_week'] = df_sales['date'].dt.day_name()
dow_avg = df_sales.groupby('day_of_week')['sales'].mean()
dow_order = ['Monday', 'Tuesday', 'Wednesday', 'Thursday', 'Friday', 'Saturday', 'Sunday']
dow_avg = dow_avg.reindex(dow_order)
bars = ax6.bar(dow_avg.index, dow_avg.values, color='coral')
ax6.set_title('Average Sales by Day of Week')
ax6.set_xlabel('Day of Week')
ax6.set_ylabel('Average Sales')
ax6.tick_params(axis='x', rotation=45)
# Add value labels on bars
for bar, value in zip(bars, dow_avg.values):
ax6.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 10,
f'{value:.0f}', ha='center', va='bottom')
# 7. Sales heatmap (by month and day of week)
ax7 = fig.add_subplot(gs[2, 2])
df_sales['year_month'] = df_sales['date'].dt.to_period('M')
heatmap_data = df_sales.groupby(['year_month', 'day_of_week'])['sales'].mean().unstack()
heatmap_data = heatmap_data.reindex(columns=dow_order)
sns.heatmap(heatmap_data.T, cmap='YlOrRd', ax=ax7, cbar_kws={'label': 'Average Sales'})
ax7.set_title('Sales Heatmap (Month vs Day of Week)')
ax7.set_xlabel('Month')
ax7.set_ylabel('Day of Week')
plt.suptitle('Sales Analytics Dashboard', fontsize=16, fontweight='bold', y=1.02)
plt.show()
# 5. Publication Quality Plots
def publication_quality():
print("\n=== Publication Quality Plots ===")
# Create high-quality scientific plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))
# Plot 1: Comparison of methods with error bars
methods = ['Method A', 'Method B', 'Method C', 'Method D']
means = [85.2, 78.9, 92.1, 88.7]
stds = [3.2, 4.1, 2.8, 3.5]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728']
bars = ax1.bar(methods, means, yerr=stds, capsize=5,
color=colors, alpha=0.8, edgecolor='black', linewidth=1.2)
ax1.set_title('Algorithm Performance Comparison', fontsize=14, fontweight='bold')
ax1.set_ylabel('Accuracy (%)', fontsize=12)
ax1.set_ylim(0, 100)
ax1.grid(True, alpha=0.3, axis='y')
# Add significance markers
significance = ['***', 'ns', '****', '**']
for i, (bar, sig) in enumerate(zip(bars, significance)):
ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + stds[i] + 2,
sig, ha='center', va='bottom', fontsize=12, fontweight='bold')
# Plot 2: ROC curves
fpr = np.linspace(0, 1, 100)
# Simulate ROC curves for different models
tpr_model1 = 1 - np.exp(-3 * fpr)
    tpr_model2 = fpr**(1/3)     # cube-root curve, above the square-root curve
    tpr_model3 = np.sqrt(fpr)   # square-root curve, the weakest of the three
    ax2.plot(fpr, tpr_model1, 'b-', linewidth=2, label='Model 1 (AUC = 0.91)')
    ax2.plot(fpr, tpr_model2, 'r-', linewidth=2, label='Model 2 (AUC = 0.82)')
    ax2.plot(fpr, tpr_model3, 'g-', linewidth=2, label='Model 3 (AUC = 0.76)')
ax2.plot([0, 1], [0, 1], 'k--', linewidth=1, label='Random')
ax2.set_title('ROC Curves Comparison', fontsize=14, fontweight='bold')
ax2.set_xlabel('False Positive Rate', fontsize=12)
ax2.set_ylabel('True Positive Rate', fontsize=12)
ax2.legend(loc='lower right')
ax2.grid(True, alpha=0.3)
# Style improvements for publication
for ax in [ax1, ax2]:
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(1.2)
ax.spines['bottom'].set_linewidth(1.2)
ax.tick_params(axis='both', which='major', labelsize=10)
plt.tight_layout()
plt.show()
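# Added sketch: rc settings can be scoped with plt.rc_context so that
# publication tweaks like the ones above don't leak into other figures.
def scoped_style_demo():
    with plt.rc_context({'font.size': 14, 'axes.linewidth': 1.2}):
        fig, ax = plt.subplots(figsize=(4, 3))
        ax.plot([0, 1], [0, 1], 'b-')
        ax.set_title('Scoped rc settings')
        plt.tight_layout()
        plt.show()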
# Main execution
if __name__ == "__main__":
print("Matplotlib and Seaborn Visualization Examples")
print("=" * 50)
basic_plots()
advanced_plots()
seaborn_plots()
real_world_visualization()
publication_quality()
print("\n" + "=" * 50)
print("All visualization examples completed successfully!")