# tradai-common
Shared utilities, base classes, and AWS integrations for the TradAI platform.
## Base Classes

### BaseService

Base class for all TradAI microservices, with lifecycle management.

```python
from tradai.common import BaseService


class MyService(BaseService):
    """Custom service implementation."""

    async def start(self) -> None:
        """Initialize service resources."""
        await super().start()
        # Custom initialization goes here, after the base class has started

    async def stop(self) -> None:
        """Clean up service resources."""
        # Custom cleanup goes here, before the base class shuts down
        await super().stop()
```

Methods:

| Method | Description |
|---|---|
| start() | Initialize the service; called on startup |
| stop() | Clean up resources; called on shutdown |
| health_check() | Return the service's health status |
### Settings

Pydantic settings base class with environment variable support.

```python
from pydantic_settings import BaseSettings, SettingsConfigDict


class MySettings(BaseSettings):
    # Read from MY_SERVICE_DATABASE_URL and MY_SERVICE_DEBUG (environment or .env)
    database_url: str
    debug: bool = False

    model_config = SettingsConfigDict(
        env_prefix="MY_SERVICE_",
        env_file=".env",
        extra="ignore",  # ignore unrelated keys in .env instead of raising
    )
```

Features:

- Automatic environment variable loading
- Type validation
- Default values
- Nested configuration support
### LoggerMixin

Mixin class that provides consistent logging across all components.

```python
from tradai.common import LoggerMixin


class MyClass(LoggerMixin):
    def do_something(self) -> None:
        self.logger.info("Doing something")
        self.logger.debug("Debug details", extra={"key": "value"})
```

Log levels:

- DEBUG: detailed diagnostic information
- INFO: key operational events
- WARNING: recoverable issues
- ERROR: failures requiring attention
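A mixin like this is commonly written as a read-only property that returns a logger named after the concrete class. The stdlib-only sketch below illustrates that general technique; it is an assumption about the pattern, not the source of `tradai.common.LoggerMixin`.

```python
import logging


class LoggerMixin:
    """Give every subclass a logger named '<module>.<ClassName>'."""

    @property
    def logger(self) -> logging.Logger:
        # logging.getLogger caches loggers by name, so repeated access is cheap
        # and always returns the same Logger object for a given class.
        cls = type(self)
        return logging.getLogger(f"{cls.__module__}.{cls.__qualname__}")


class MyClass(LoggerMixin):
    def do_something(self) -> None:
        self.logger.info("Doing something")
```

Naming loggers by module and class lets you raise or lower verbosity for one component (for example `logging.getLogger("mypkg.MyClass").setLevel(logging.DEBUG)`) without touching the rest.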
## Entities

### BacktestConfig

Configuration for a backtest run.

```python
from datetime import UTC, datetime

from tradai.common import BacktestConfig, DateRange  # DateRange export location assumed

# Example backtest period
start_date = datetime(2024, 1, 1, tzinfo=UTC)
end_date = datetime(2024, 3, 1, tzinfo=UTC)

config = BacktestConfig(
    strategy="TrendFollowingStrategy",
    pairs=["BTC/USDT:USDT", "ETH/USDT:USDT"],
    timeframe="1h",
    date_range=DateRange(start=start_date, end=end_date),
    stake_amount=100.0,
    dry_run_wallet=10000.0,
)
```

Fields:

| Field | Type | Description |
|---|---|---|
| strategy | str | Strategy class name |
| pairs | list[str] | Trading pairs |
| timeframe | str | Candle timeframe |
| date_range | DateRange | Backtest period |
| stake_amount | float | Amount staked per trade |
| dry_run_wallet | float | Starting capital |
### BacktestResult

Results from a completed backtest.

```python
from tradai.common import BacktestResult


def print_summary(result: BacktestResult) -> None:
    """Print headline metrics from a completed backtest."""
    print(f"Sharpe Ratio: {result.sharpe_ratio}")
    print(f"Total Profit: {result.total_profit_pct}%")
    print(f"Max Drawdown: {result.max_drawdown_pct}%")
```

Fields:

| Field | Type | Description |
|---|---|---|
| total_trades | int | Number of trades |
| winning_trades | int | Profitable trades |
| losing_trades | int | Unprofitable trades |
| total_profit | float | Absolute profit |
| total_profit_pct | float | Profit percentage |
| sharpe_ratio | float | Risk-adjusted return |
| sortino_ratio | float | Downside-risk-adjusted return |
| profit_factor | float | Gross profit / gross loss |
| win_rate | float | Win percentage |
| max_drawdown_pct | float | Maximum drawdown (percentage) |
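To make the count- and drawdown-based fields concrete, here is a stdlib-only sketch that derives several of them from a list of per-trade absolute profits and a starting balance. `summarize_trades` and `TradeSummary` are hypothetical names, and TradAI may compute these values differently (for example, from an equity curve sampled per candle rather than per closed trade). Sharpe and Sortino are left out because they depend on the return series and the annualization convention.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TradeSummary:
    total_trades: int
    winning_trades: int
    losing_trades: int
    total_profit: float
    win_rate: float          # percent of trades with profit > 0
    profit_factor: float     # gross profit / gross loss (inf when there are no losses)
    max_drawdown_pct: float  # deepest peak-to-trough equity drop, in percent


def summarize_trades(profits: list[float], starting_balance: float) -> TradeSummary:
    """Summarize closed trades; assumes starting_balance > 0."""
    wins = [p for p in profits if p > 0]
    losses = [p for p in profits if p < 0]  # zero-profit trades count as neither
    gross_profit = sum(wins)
    gross_loss = -sum(losses)
    profit_factor = gross_profit / gross_loss if gross_loss else float("inf")

    # Walk the per-trade equity curve and track the deepest drop from the running peak.
    equity = peak = starting_balance
    max_dd = 0.0
    for p in profits:
        equity += p
        peak = max(peak, equity)
        max_dd = max(max_dd, (peak - equity) / peak * 100)

    n = len(profits)
    return TradeSummary(
        total_trades=n,
        winning_trades=len(wins),
        losing_trades=len(losses),
        total_profit=sum(profits),
        win_rate=len(wins) / n * 100 if n else 0.0,
        profit_factor=profit_factor,
        max_drawdown_pct=max_dd,
    )
```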
### JobStatus

Enumeration of backtest job states.

```python
from tradai.common import JobStatus

JobStatus.PENDING    # Job queued
JobStatus.RUNNING    # Job executing
JobStatus.COMPLETED  # Job finished successfully
JobStatus.FAILED     # Job failed with an error
JobStatus.CANCELLED  # Job cancelled by the user
```
## Exceptions

### Exception Hierarchy

```text
TradAIError (base)
├── ValidationError        # Input validation failures
├── ConfigurationError     # Configuration issues
├── NotFoundError          # Resource not found
├── DataFetchError         # Data retrieval failures
├── DataNotFoundError      # Missing data
├── StorageError           # Storage operations failed
├── BacktestError          # Backtest execution failed
├── TradingError           # Trading operation failed
├── AuthenticationError    # Auth failures
└── CircuitOpenError       # Circuit breaker triggered
```

Usage:

```python
from ccxt import ExchangeError  # assumes a ccxt exchange; substitute your client's error type

from tradai.common import DataFetchError, ValidationError


def validate_config(config) -> None:
    if not config.pairs:
        raise ValidationError("At least one pair is required")


async def fetch_data(exchange, symbol: str) -> list:
    """Fetch OHLCV candles; `exchange` is an async ccxt exchange instance."""
    try:
        return await exchange.fetch_ohlcv(symbol)
    except ExchangeError as e:
        # Chain the original error so the root cause stays in the traceback
        raise DataFetchError(f"Failed to fetch {symbol}") from e
```
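The hierarchy pays off when callers catch the base class and can still inspect the root cause. This self-contained sketch redefines two classes from the tree above locally (it does not import `tradai.common`) and uses a hypothetical `UpstreamTimeout` in place of a third-party library error, to show what `raise ... from e` preserves.

```python
class TradAIError(Exception):
    """Base class, so callers can catch every platform error in one clause."""


class DataFetchError(TradAIError):
    """Raised when market data cannot be retrieved."""


class UpstreamTimeout(Exception):
    """Hypothetical stand-in for an error raised by an external library."""


def fetch(symbol: str) -> list:
    try:
        raise UpstreamTimeout(f"timeout fetching {symbol}")
    except UpstreamTimeout as e:
        # "from e" stores the original exception on __cause__, so tracebacks
        # show both the domain error and the low-level reason.
        raise DataFetchError(f"Failed to fetch {symbol}") from e
```

A handler written as `except TradAIError as err:` catches the `DataFetchError`, and `err.__cause__` still holds the `UpstreamTimeout`.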
## Health Checking

### HealthService

Manages service health with dependency checks.

```python
from tradai.common import HealthService
from tradai.common import DatabaseHealthChecker, RedisHealthChecker  # import path assumed


async def report_health(db, redis) -> str:
    """Register dependency checkers and run them; db and redis are existing clients."""
    health_service = HealthService()
    health_service.add_checker("database", DatabaseHealthChecker(db))
    health_service.add_checker("redis", RedisHealthChecker(redis))

    result = await health_service.check()
    return result.status  # "healthy" or "unhealthy"
```

### Built-in Checkers

| Checker | Description |
|---|---|
| DatabaseHealthChecker | PostgreSQL/MySQL connectivity |
| RedisHealthChecker | Redis connectivity |
| HTTPHealthChecker | HTTP endpoint availability |
| DynamoDBHealthChecker | DynamoDB table access |
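The aggregation rule described above (overall "healthy" only when every dependency passes) can be sketched with `asyncio.gather`. Checkers are modeled here as plain async callables returning `bool`, and `check_all` and its per-check `timeout` are hypothetical; the real `HealthChecker` interface may differ.

```python
import asyncio
from collections.abc import Awaitable, Callable

Checker = Callable[[], Awaitable[bool]]


async def check_all(checkers: dict[str, Checker], timeout: float = 2.0) -> dict:
    """Run all checkers concurrently; healthy only if every one passes."""

    async def run_one(check: Checker) -> bool:
        try:
            return await asyncio.wait_for(check(), timeout)
        except Exception:
            # A checker that raises or times out counts as unhealthy.
            return False

    names = list(checkers)
    results = await asyncio.gather(*(run_one(checkers[n]) for n in names))
    details = dict(zip(names, results))
    status = "healthy" if all(details.values()) else "unhealthy"
    return {"status": status, "checks": details}
```

Running checks concurrently keeps the total latency close to the slowest dependency rather than the sum of all of them.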
## HTTP Client

### HttpClient

Async HTTP client with retry and circuit breaker support.

```python
from tradai.common import HttpClient, HttpClientConfig, RetryConfig

config = HttpClientConfig(
    base_url="https://api.example.com",
    timeout=30.0,
    retry=RetryConfig(max_retries=3, backoff_factor=0.5),
)


async def fetch_data() -> dict:
    async with HttpClient(config) as client:
        response = await client.get("/api/v1/data")
        return response.json()
```

Features:

- Automatic retries with exponential backoff
- Circuit breaker pattern
- Request/response logging
- Timeout handling
## MLflow Integration

### MLflowAdapter

Wrapper for MLflow experiment tracking.

```python
from tradai.common import MLflowAdapter

adapter = MLflowAdapter(tracking_uri="http://localhost:5000")

# Log parameters, metrics, and an artifact to a run in the "backtest" experiment
with adapter.start_run(experiment_name="backtest"):
    adapter.log_params({"strategy": "TrendFollowing"})
    adapter.log_metrics({"sharpe": 1.85, "profit": 25.5})
    adapter.log_artifact("results.json")

# Query runs with an MLflow filter expression
runs = adapter.search_runs(
    experiment_name="backtest",
    filter_string="metrics.sharpe > 1.0",
)
```
### ModelVersion

MLflow model version metadata.

```python
from tradai.common import ModelVersion

version = ModelVersion(
    name="TrendFollowingStrategy",
    version="3",
    stage="Production",
    run_id="abc123",
)
```
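## Circuit Breaker

### CircuitBreaker

Prevents cascading failures by failing fast when a dependency is unhealthy.

```python
from tradai.common import CircuitBreaker, CircuitBreakerConfig

config = CircuitBreakerConfig(
    failure_threshold=5,    # failures before the circuit opens
    recovery_timeout=30.0,  # seconds to wait before testing recovery
    half_open_requests=3,   # trial requests allowed while half-open
)
breaker = CircuitBreaker("external-api", config)


async def call_external():
    # `external_api` is a placeholder for any async client
    async with breaker:
        return await external_api.call()
```

States:

- CLOSED: normal operation; calls pass through
- OPEN: failing fast; calls are rejected without reaching the dependency
- HALF_OPEN: testing recovery with a limited number of trial requests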
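The three states form a small state machine. Below is a minimal synchronous sketch, assuming consecutive-failure counting and that a single successful trial request closes the circuit; the real `CircuitBreaker` may use different rules. `SimpleBreaker`, `State`, and the injectable `clock` are hypothetical names, and the clock exists only so the transitions can be tested without sleeping.

```python
import time
from collections.abc import Callable
from enum import Enum


class State(Enum):
    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


class SimpleBreaker:
    """Consecutive-failure circuit breaker (illustrative, synchronous)."""

    def __init__(
        self,
        failure_threshold: int = 5,
        recovery_timeout: float = 30.0,
        half_open_requests: int = 3,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_requests = half_open_requests
        self.clock = clock
        self.state = State.CLOSED
        self.failures = 0
        self.opened_at = 0.0
        self.trial_calls = 0

    def allow(self) -> bool:
        """Return True if a call may go through right now."""
        if self.state is State.OPEN:
            if self.clock() - self.opened_at < self.recovery_timeout:
                return False  # still cooling down: fail fast
            self.state = State.HALF_OPEN  # timeout elapsed: start probing
            self.trial_calls = 0
        if self.state is State.HALF_OPEN:
            if self.trial_calls >= self.half_open_requests:
                return False  # trial budget used up; wait for an outcome
            self.trial_calls += 1
        return True

    def record_success(self) -> None:
        self.state = State.CLOSED
        self.failures = 0

    def record_failure(self) -> None:
        self.failures += 1
        # A failed trial reopens immediately; while CLOSED, open at the threshold.
        if self.state is State.HALF_OPEN or self.failures >= self.failure_threshold:
            self.state = State.OPEN
            self.opened_at = self.clock()
```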
## A/B Testing

### ABTestManager

Manages champion/challenger experiments.

```python
from tradai.common import ABTestManager, ABTestConfig

config = ABTestConfig(
    champion_version="2",
    challenger_version="3",
    min_trades=100,
    significance_level=0.05,
)
manager = ABTestManager(config)

# Record trade outcomes for each arm; evaluate() is meaningful only after min_trades
manager.add_trade("champion", profit_pct=1.5)
manager.add_trade("challenger", profit_pct=2.1)

result = manager.evaluate()
print(result.recommendation)  # e.g. "PROMOTE_CHALLENGER" once enough trades are recorded
```
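How `ABTestManager.evaluate()` reaches its recommendation is not documented here. As one hedged illustration of what `min_trades` and `significance_level` can mean, the sketch below applies a two-sided z-test (a normal approximation, reasonable for large samples) to mean per-trade profit. The function name and the `"KEEP_CHAMPION"` and `"INSUFFICIENT_DATA"` labels are hypothetical.

```python
import math
import statistics


def ab_recommendation(
    champion: list[float],
    challenger: list[float],
    significance_level: float = 0.05,
    min_trades: int = 100,
) -> str:
    """Two-sided z-test on mean per-trade profit_pct (normal approximation)."""
    # variance() needs at least two samples per arm
    if min(len(champion), len(challenger)) < max(min_trades, 2):
        return "INSUFFICIENT_DATA"

    diff = statistics.fmean(challenger) - statistics.fmean(champion)
    # Standard error of the difference, using each arm's own sample variance
    se = math.sqrt(
        statistics.variance(champion) / len(champion)
        + statistics.variance(challenger) / len(challenger)
    )
    if se == 0:
        return "PROMOTE_CHALLENGER" if diff > 0 else "KEEP_CHAMPION"

    z = diff / se
    p_value = math.erfc(abs(z) / math.sqrt(2))  # two-sided p-value under N(0, 1)
    if p_value < significance_level and diff > 0:
        return "PROMOTE_CHALLENGER"
    return "KEEP_CHAMPION"
```

Requiring both significance and a positive difference means a challenger that is significantly worse is never promoted.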
## Model Comparison

### ModelComparator

Compares model versions to support promotion decisions.

```python
from tradai.common import ModelComparator, ModelCandidate

comparator = ModelComparator()
champion = ModelCandidate(version="2", metrics={"sharpe": 1.5})
challenger = ModelCandidate(version="3", metrics={"sharpe": 1.8})

result = comparator.compare(champion, challenger)
print(result.decision)    # e.g. "PROMOTE"
print(result.confidence)  # e.g. 0.85
```
## Drift Detection

### DriftDetector

Detects feature and prediction drift.

```python
from tradai.common import DriftDetector, DriftThresholds

detector = DriftDetector(
    thresholds=DriftThresholds(
        psi_threshold=0.2,  # Population Stability Index
        ks_threshold=0.1,   # Kolmogorov-Smirnov statistic
    )
)

# training_df and production_df: tabular data (e.g. DataFrames) with the same feature columns
result = detector.detect(
    reference_data=training_df,
    current_data=production_df,
)
if result.has_drift:
    print(f"Drift detected: {result.drifted_features}")
```
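For intuition about the two thresholds, here is a stdlib-only sketch of the Population Stability Index (PSI) and the two-sample Kolmogorov-Smirnov (KS) statistic for a single numeric feature. PSI sums (current share - reference share) * ln(current share / reference share) over bins; KS is the largest gap between the two empirical CDFs. The binning choice (ten equal-width bins over the reference range) and the small epsilon for empty bins are assumptions; `DriftDetector` may bin differently or rely on a library implementation.

```python
import bisect
import math


def psi(reference: list[float], current: list[float], bins: int = 10) -> float:
    """Population Stability Index with equal-width bins over the reference range."""
    lo, hi = min(reference), max(reference)
    width = (hi - lo) / bins or 1.0  # guard against a constant reference column
    eps = 1e-6  # keeps log() finite when a bin is empty

    def shares(values: list[float]) -> list[float]:
        counts = [0] * bins
        for v in values:
            # Clamp so values outside the reference range land in the edge bins.
            idx = min(max(int((v - lo) / width), 0), bins - 1)
            counts[idx] += 1
        return [max(c / len(values), eps) for c in counts]

    ref, cur = shares(reference), shares(current)
    return sum((c - r) * math.log(c / r) for r, c in zip(ref, cur))


def ks_statistic(a: list[float], b: list[float]) -> float:
    """Two-sample KS statistic D: the largest gap between the empirical CDFs."""
    a, b = sorted(a), sorted(b)
    d = 0.0
    # Both ECDFs are step functions, so checking every sample point finds the max gap.
    for x in a + b:
        fa = bisect.bisect_right(a, x) / len(a)
        fb = bisect.bisect_right(b, x) / len(b)
        d = max(d, abs(fa - fb))
    return d
```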
## Utilities

### with_retry

Decorator that retries a function automatically, with backoff between attempts.

```python
from tradai.common import with_retry


@with_retry(max_attempts=3, backoff_factor=2.0)
async def fetch_data():
    # `api` is a placeholder for any async client
    return await api.call()
```
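The parameter names suggest exponential backoff. A minimal sketch of such a decorator for async functions follows; the name `retry_async`, the `base_delay` and `retry_on` parameters, and the delay formula `base_delay * backoff_factor ** attempt` are assumptions, not the documented behavior of `with_retry`.

```python
import asyncio
import functools
from collections.abc import Awaitable, Callable
from typing import Any, TypeVar

T = TypeVar("T")


def retry_async(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    backoff_factor: float = 2.0,
    retry_on: tuple[type[Exception], ...] = (Exception,),
) -> Callable[[Callable[..., Awaitable[T]]], Callable[..., Awaitable[T]]]:
    """Retry an async function, sleeping base_delay * backoff_factor**attempt between tries."""
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")

    def decorator(func: Callable[..., Awaitable[T]]) -> Callable[..., Awaitable[T]]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> T:
            for attempt in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except retry_on:
                    if attempt == max_attempts - 1:
                        raise  # out of attempts: surface the last error
                    # Delays grow as base_delay, base_delay * factor, base_delay * factor**2, ...
                    await asyncio.sleep(base_delay * backoff_factor**attempt)
            raise AssertionError("unreachable")

        return wrapper

    return decorator
```

Narrowing `retry_on` to transient errors (timeouts, connection resets) avoids retrying failures that will never succeed, such as validation errors.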
### FreqtradeCLIBuilder

Builds Freqtrade CLI commands programmatically.

```python
from tradai.common import FreqtradeCLIBuilder

cmd = FreqtradeCLIBuilder()
cmd.backtesting()
cmd.strategy("MyStrategy")
cmd.timerange("20240101-20240301")
cmd.pairs(["BTC/USDT:USDT"])

args = cmd.build()  # ['backtesting', '-s', 'MyStrategy', ...]
```
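Returning a list of arguments rather than one command string means the result can go straight to `subprocess.run(["freqtrade", *args])` with no shell quoting. The class below is a hypothetical stand-in that shows the builder pattern with method chaining; whether `FreqtradeCLIBuilder` itself supports chaining is not stated above. The flags used (`-s`, `--timerange`, and `-p`, which takes space-separated pairs) are standard Freqtrade options.

```python
class CLIBuilder:
    """Accumulates Freqtrade arguments; each method returns self for chaining."""

    def __init__(self) -> None:
        self._command: str | None = None
        self._args: list[str] = []

    def backtesting(self) -> "CLIBuilder":
        self._command = "backtesting"
        return self

    def strategy(self, name: str) -> "CLIBuilder":
        self._args += ["-s", name]
        return self

    def timerange(self, value: str) -> "CLIBuilder":
        self._args += ["--timerange", value]
        return self

    def pairs(self, pairs: list[str]) -> "CLIBuilder":
        # -p/--pairs takes space-separated values, i.e. one argv item per pair
        self._args += ["-p", *pairs]
        return self

    def build(self) -> list[str]:
        if self._command is None:
            raise ValueError("no subcommand selected")
        return [self._command, *self._args]
```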
## AWS Module

The AWS module provides pre-configured, thread-safe clients for AWS services.

```python
from tradai.common.aws import (
    DynamoDBAdapter, AsyncDynamoDBAdapter,
    SNSPublisher, MetricsPublisher,
    AsyncS3Client, S3ConfigRepository,
    get_secret,
)
```
### DynamoDBAdapter

Thread-safe DynamoDB operations with automatic serialization.

```python
from tradai.common.aws import DynamoDBAdapter

adapter = DynamoDBAdapter(table_name="tradai-state")

# Put an item
adapter.put_item({"pk": "strategy-1", "status": "running"})

# Get an item by key
item = adapter.get_item({"pk": "strategy-1"})

# Query by partition key (a key condition expression, not a filter)
items = adapter.query(
    key_condition="pk = :pk",
    expression_values={":pk": "strategy-1"},
)
```
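One concrete reason serialization matters: boto3's DynamoDB serializer rejects Python `float` values and expects `decimal.Decimal` instead. The sketch below shows the kind of conversion an adapter can apply before `put_item`; the helper name `to_dynamodb` is hypothetical, and `DynamoDBAdapter`'s actual serialization is not shown in this document.

```python
from decimal import Decimal
from typing import Any


def to_dynamodb(value: Any) -> Any:
    """Recursively convert floats to Decimal so boto3 accepts the item."""
    if isinstance(value, float):
        # Going through str() gives Decimal("0.1"), not the full binary expansion of 0.1.
        # Note: NaN and infinity would still be rejected by DynamoDB.
        return Decimal(str(value))
    if isinstance(value, dict):
        return {k: to_dynamodb(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_dynamodb(v) for v in value]
    return value  # str, int, bool, None, bytes pass through unchanged
```

Reading an item back yields `Decimal` for every number, so code that does arithmetic on loaded values should expect `Decimal` (or convert explicitly).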
### SNSPublisher

Publishes alerts and notifications to SNS topics.

```python
from tradai.common.aws import SNSPublisher

publisher = SNSPublisher(topic_arn="arn:aws:sns:...")
publisher.publish(
    subject="Alert: Model Drift Detected",
    message="Model PascalStrategy shows significant drift",
    message_attributes={"severity": "HIGH", "model": "PascalStrategy"},
)
```
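At the boto3 level, SNS `publish` takes `MessageAttributes` as a map of `{"DataType": ..., "StringValue": ...}` entries, so a publisher that accepts a flat dict like the one above presumably converts it. A sketch of that conversion, with a hypothetical helper name and only the `String` and `Number` data types handled:

```python
def to_sns_attributes(attrs: dict[str, object]) -> dict[str, dict[str, str]]:
    """Convert flat key/value pairs into SNS MessageAttributes entries."""
    entries: dict[str, dict[str, str]] = {}
    for name, value in attrs.items():
        if isinstance(value, str):
            entries[name] = {"DataType": "String", "StringValue": value}
        elif isinstance(value, (int, float)) and not isinstance(value, bool):
            # SNS carries numbers as strings tagged with the "Number" data type.
            entries[name] = {"DataType": "Number", "StringValue": str(value)}
        else:
            raise TypeError(f"unsupported attribute type for {name!r}: {type(value).__name__}")
    return entries
```

Message attributes are also what SNS subscription filter policies match on, so a subscriber could, for example, receive only `severity = HIGH` alerts.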
### MetricsPublisher

Publishes CloudWatch custom metrics.

```python
from tradai.common.aws import MetricsPublisher, MetricDatum, MetricUnit

publisher = MetricsPublisher(namespace="TradAI/Backtests")
publisher.publish([
    MetricDatum(
        name="BacktestDuration",
        value=125.5,
        unit=MetricUnit.SECONDS,
        dimensions={"Strategy": "PascalStrategy"},
    ),
    MetricDatum(
        name="SharpeRatio",
        value=1.85,
        unit=MetricUnit.NONE,
        dimensions={"Strategy": "PascalStrategy"},
    ),
])
```
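Underneath, CloudWatch's `PutMetricData` API (boto3 `cloudwatch.put_metric_data(Namespace=..., MetricData=[...])`) takes entries with `MetricName`, `Value`, `Unit`, and a list of `Dimensions` name/value pairs. The sketch below shows that mapping with a hypothetical `Datum` dataclass standing in for `MetricDatum`; it is not `MetricsPublisher`'s code.

```python
from dataclasses import dataclass, field


@dataclass(frozen=True)
class Datum:
    name: str
    value: float
    unit: str = "None"  # CloudWatch unit names, e.g. "Seconds", "Count", "Milliseconds"
    dimensions: dict[str, str] = field(default_factory=dict)


def to_metric_data(data: list[Datum]) -> list[dict]:
    """Build the MetricData list passed to boto3's put_metric_data."""
    return [
        {
            "MetricName": d.name,
            "Value": d.value,
            "Unit": d.unit,
            "Dimensions": [{"Name": k, "Value": v} for k, v in d.dimensions.items()],
        }
        for d in data
    ]
```

CloudWatch treats each distinct combination of dimensions as a separate metric, so `{"Strategy": ...}` yields one time series per strategy.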
### AsyncS3Client

Async S3 operations for data storage.

```python
import json

from tradai.common.aws import AsyncS3Client


async def save_and_reload(result: dict):
    async with AsyncS3Client() as client:
        # Upload a JSON document
        await client.upload_to_s3(
            bucket="tradai-data",
            key="backtests/result.json",
            data=json.dumps(result),
        )

        # Download it again
        content = await client.download_from_s3(
            bucket="tradai-data",
            key="backtests/result.json",
        )
    return content
```
### DynamoDBStateRepository

Generic state repository for Lambda handlers.

```python
from tradai.common.aws.state_repository import DynamoDBStateRepository
from tradai.common.lambda_ import HealthState

repo = DynamoDBStateRepository(
    table_name="tradai-health-state",
    key_name="service_name",
    entity_class=HealthState,
    ttl_seconds=86400,  # 24 hours
)

# Read and write state
state = repo.get("backend-api")
repo.put(HealthState(service_name="backend-api", status="healthy"))
```
### get_secret

Retrieves secrets from AWS Secrets Manager.

```python
from tradai.common.aws import get_secret

# Returns parsed JSON for JSON secrets, otherwise the raw string
db_creds = get_secret("tradai/database/credentials")
api_key = get_secret("tradai/exchange/api-key")
```
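Secrets Manager returns the payload as a string (`SecretString` in the `get_secret_value` response), and credentials are commonly stored as JSON objects. A sketch of the "parsed JSON or raw string" behavior, assuming only JSON objects and arrays are parsed (how `get_secret` treats scalars such as `"12345"` is not documented); `parse_secret` is a hypothetical name.

```python
import json
from typing import Any


def parse_secret(secret_string: str) -> Any:
    """Return a dict/list when the secret is structured JSON, otherwise the raw string."""
    try:
        parsed = json.loads(secret_string)
    except json.JSONDecodeError:
        return secret_string  # plain-text secret such as an API key
    # Only structured JSON counts; a bare number like "12345" stays a string.
    return parsed if isinstance(parsed, (dict, list)) else secret_string
```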
## Lambda Module

The Lambda module provides decorators and utilities for building Lambda handlers.

```python
from tradai.common.lambda_ import (
    lambda_handler,
    LambdaContext,
    LambdaResponse,
    LambdaSettings,
)
```
### @lambda_handler Decorator

Wraps Lambda handlers with automatic setup, error handling, and warm-start optimization.

```python
from tradai.common.lambda_ import (
    lambda_handler,
    LambdaContext,
    LambdaResponse,
    HealthCheckSettings,
)


@lambda_handler(HealthCheckSettings, cloudwatch_namespace_suffix="ServiceHealth")
def handler(event: dict, ctx: LambdaContext[HealthCheckSettings]) -> dict:
    """My Lambda handler."""
    settings = ctx.settings

    # Use the pre-configured publishers
    ctx.metrics_publisher.publish([...])
    ctx.alert_publisher.publish(subject="...", message="...")

    return LambdaResponse.success(data={"healthy": True}).to_dict()
```
Features:
| Feature | Description |
|---|---|
| Settings loading | Automatic env var loading via Pydantic |
| Context creation | Pre-configured SNS, CloudWatch publishers |
| Warm start caching | 55-minute TTL cache for context |
| Error handling | Consistent error response format |
| Step Functions | Optional step_functions=True for SF-compatible errors |
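Warm-start caching works because module-level objects survive between invocations that reuse the same Lambda execution environment. A minimal sketch of a time-bounded cache with the 55-minute TTL from the table; `TTLCache` and its injectable `clock` are hypothetical, not the decorator's internals.

```python
import time
from collections.abc import Callable
from typing import Generic, Optional, TypeVar, cast

T = TypeVar("T")


class TTLCache(Generic[T]):
    """Holds one lazily built value and rebuilds it once it is ttl_seconds old."""

    def __init__(
        self,
        factory: Callable[[], T],
        ttl_seconds: float = 55 * 60,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._factory = factory
        self._ttl = ttl_seconds
        self._clock = clock
        self._value: Optional[T] = None
        self._built_at: Optional[float] = None  # None until the first build

    def get(self) -> T:
        now = self._clock()
        if self._built_at is None or now - self._built_at >= self._ttl:
            self._value = self._factory()  # cold start or expired: rebuild
            self._built_at = now
        return cast(T, self._value)
```

In a handler module the cache would live at module scope (for example `_context = TTLCache(build_context)`, where `build_context` is a hypothetical factory), so warm invocations call `_context.get()` and skip the setup cost.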
### LambdaContext

Dependency container injected into handlers.

```python
# Accessed through the handler's ctx parameter
ctx.settings           # LambdaSettings instance
ctx.logger             # Configured logger
ctx.metrics_publisher  # MetricsPublisher for CloudWatch
ctx.alert_publisher    # SNSPublisher for alerts (if enabled)
```
### LambdaResponse

Builder for consistent Lambda responses.

```python
from tradai.common.lambda_ import LambdaResponse

# Success response (API Gateway format)
LambdaResponse.success(data={"result": "ok"}).to_dict()
# {"statusCode": 200, "body": {"success": true, "data": {...}}}

# For Step Functions
LambdaResponse.success(data={"result": "ok"}).to_step_functions()
# {"statusCode": 200, "body": {...}}; Step Functions reads it via $.Payload.body.data

# Error response
LambdaResponse.error(
    message="Validation failed",
    error_type="ValidationError",
).to_dict()
```
### LambdaSettings Classes

Pre-built settings for common Lambda patterns:

| Class | Use Case |
|---|---|
| LambdaSettings | Base class with common fields |
| DynamoDBSettings | Adds DynamoDB table configuration |
| HealthCheckSettings | Health check Lambdas |
| DriftMonitorSettings | Drift detection Lambdas |
| RetrainingSchedulerSettings | Model retraining Lambdas |
| ModelManagementSettings | MLflow operations |
### Entity Classes

DynamoDB entities for state tracking:

```python
from datetime import UTC, datetime

from tradai.common.lambda_ import HealthState, DriftState, HeartbeatState

# All entities implement to_dynamodb_item() and from_dynamodb_item()
state = HealthState(
    service_name="backend-api",
    status="healthy",
    consecutive_failures=0,
    last_check=datetime.now(UTC).isoformat(),
)
item = state.to_dynamodb_item()  # Ready for DynamoDB put_item
```
## Clients Module

Service clients for inter-service communication, with circuit breaker integration.

### DataCollectionClient

Client for the data-collection service.

```python
from tradai.common.clients import DataCollectionClient, DataCollectionClientConfig

config = DataCollectionClientConfig(
    base_url="http://data-collection.tradai.local:8002",
    timeout=30.0,
)
client = DataCollectionClient(config)

# Sync market data
result = client.sync_data(
    symbols=["BTC/USDT:USDT", "ETH/USDT:USDT"],
    start_date="2024-01-01",
    end_date="2024-01-31",
)

# Check data freshness
freshness = client.check_freshness(
    symbols=["BTC/USDT:USDT"],
    stale_threshold_hours=24,
)
if not freshness.all_fresh:
    print(f"Stale symbols: {freshness.stale_symbols}")

# List available symbols
symbols = client.list_symbols()
```
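The freshness rule can be read as a cutoff comparison: a symbol is stale when its most recent data is older than `now - stale_threshold_hours`. This is an assumption about how `check_freshness` applies the threshold; `stale_symbols` is a hypothetical helper, and symbols with no data at all would need separate handling.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def stale_symbols(
    last_updated: dict[str, datetime],
    stale_threshold_hours: float,
    now: Optional[datetime] = None,
) -> list[str]:
    """Return symbols whose newest data is older than the threshold, sorted."""
    now = now or datetime.now(timezone.utc)  # compare timezone-aware datetimes only
    cutoff = now - timedelta(hours=stale_threshold_hours)
    return sorted(symbol for symbol, ts in last_updated.items() if ts < cutoff)
```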
### StrategyServiceClient

Client for the strategy-service.

```python
from tradai.common.clients import StrategyServiceClient, StrategyClientConfig

config = StrategyClientConfig(
    base_url="http://strategy-service.tradai.local:8003",
)
client = StrategyServiceClient(config)

# List strategies
strategies = client.list_strategies()

# Get strategy details
strategy = client.get_strategy("PascalStrategy")

# Submit a backtest
job = client.submit_backtest(
    strategy="PascalStrategy",
    pairs=["BTC/USDT:USDT"],
    timeframe="1h",
    start_date="2024-01-01",
    end_date="2024-01-31",
)
```
### Response Types

```python
from tradai.common.clients import (
    SyncResponse,
    FreshnessResponse,
    SymbolFreshness,
    StrategyListResponse,
)

# Type-safe response handling (client is a DataCollectionClient)
freshness: FreshnessResponse = client.check_freshness(symbols=["BTC/USDT:USDT"])
for symbol in freshness.stale_symbols:  # entries expose .symbol and .last_updated
    print(f"{symbol.symbol}: last_updated={symbol.last_updated}")
```
## Validation Module

Validation entities for dry-run and go-live deployments.

```python
from tradai.common.validation import (
    DryRunValidationReport,
    GoLiveValidationReport,
    HeartbeatMetrics,
    CandleMetrics,
)
```
### DryRunValidationReport

Report for validating strategy behavior in dry-run mode.

```python
from tradai.common.validation import (
    DryRunValidationReport,
    HeartbeatMetrics,
    CandleMetrics,
    OrderMetrics,
)

heartbeat = HeartbeatMetrics(
    total_expected=1440,  # 24h at 1-minute intervals
    total_received=1435,
    max_gap_minutes=5,
    uptime_pct=99.6,
)

candle = CandleMetrics(
    total_candles=1440,
    missing_candles=2,
    delayed_candles=5,
    completeness_pct=99.5,
)

report = DryRunValidationReport(
    strategy_id="pascal-v2",
    heartbeat_metrics=heartbeat,
    candle_metrics=candle,
    checks=[...],
)

if report.overall_passed:
    print("Ready for go-live!")
else:
    print(f"Failed checks: {report.failed_checks}")
```
### GoLiveValidationReport

Report for validating infrastructure readiness.

```python
from tradai.common.validation import (
    GoLiveValidationReport,
    GoLiveCheckResult,
    GoLiveCheckStatus,
    AlarmCheckResult,
    LambdaCheckResult,
)

alarm_check = AlarmCheckResult(
    alarm_name="high-drawdown-alarm",
    exists=True,
    enabled=True,
    status=GoLiveCheckStatus.PASSED,
)

lambda_check = LambdaCheckResult(
    function_name="model-rollback",
    exists=True,
    memory_mb=256,
    timeout_seconds=120,
    status=GoLiveCheckStatus.PASSED,
)

report = GoLiveValidationReport(
    strategy_id="pascal-v2",
    environment="prod",
    alarm_checks=[alarm_check],
    lambda_checks=[lambda_check],
)
print(f"Ready: {report.all_passed}")
```
### Check Severities

| Severity | Impact |
|---|---|
| ERROR | Blocks deployment |
| WARNING | Review recommended |
| INFO | Informational only |
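Read literally, the severity table means a deployment is blocked only by failed `ERROR` checks, while failed `WARNING` checks are surfaced for review. A sketch of that gating rule under that reading; the `Severity`, `Check`, and `evaluate` names and the example check names are hypothetical, not the fields of `DryRunValidationReport` or `GoLiveValidationReport`.

```python
from dataclasses import dataclass
from enum import Enum


class Severity(Enum):
    ERROR = "error"
    WARNING = "warning"
    INFO = "info"


@dataclass(frozen=True)
class Check:
    name: str
    passed: bool
    severity: Severity


def evaluate(checks: list[Check]) -> dict[str, list[str]]:
    """Split failed checks into deployment blockers and items needing review."""
    failed = [c for c in checks if not c.passed]
    return {
        # Deployment may proceed only when this list is empty.
        "blocking": [c.name for c in failed if c.severity is Severity.ERROR],
        "needs_review": [c.name for c in failed if c.severity is Severity.WARNING],
        # Failed INFO checks are informational and appear in neither list.
    }
```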
## Complete Examples

### Lambda Handler Pattern

Complete example of a Lambda handler with settings, state management, and metrics.

```python
import time
from datetime import UTC, datetime

import httpx

from tradai.common.aws import DynamoDBAdapter, MetricDatum, MetricUnit
from tradai.common.lambda_ import (
    lambda_handler,
    LambdaContext,
    LambdaResponse,
    LambdaSettings,
)


class MyLambdaSettings(LambdaSettings):
    """Settings for my Lambda function."""

    dynamodb_table_name: str = "tradai-health-state"  # table holding per-service state
    service_discovery_namespace: str = "tradai.local"
    health_timeout_seconds: int = 30
    consecutive_failure_threshold: int = 3


@lambda_handler(MyLambdaSettings, cloudwatch_namespace_suffix="MyLambda")
def handler(event: dict, ctx: LambdaContext[MyLambdaSettings]) -> dict:
    """Check the health of each service listed in the event."""
    services = event.get("services", [])
    results = []

    # Initialize the state table adapter
    state_repo = DynamoDBAdapter(table_name=ctx.settings.dynamodb_table_name)

    for service in services:
        # Check service health
        healthy, latency = check_service(
            service,
            ctx.settings.health_timeout_seconds,
            ctx.settings.service_discovery_namespace,
        )

        # Load previous state and update the consecutive-failure counter
        state_key = {"pk": f"HEALTH#{service['name']}"}
        existing = state_repo.get_item(state_key)
        if healthy:
            consecutive_failures = 0
        else:
            consecutive_failures = (existing.get("consecutive_failures", 0) + 1) if existing else 1

        # Save updated state
        state_repo.put_item({
            **state_key,
            "status": "healthy" if healthy else "unhealthy",
            "consecutive_failures": consecutive_failures,
            "last_check": datetime.now(UTC).isoformat(),
            "latency_ms": latency,
        })

        # Publish metrics
        ctx.metrics_publisher.publish([
            MetricDatum(
                name="ServiceHealthy",
                value=1 if healthy else 0,
                unit=MetricUnit.COUNT,
                dimensions={"Service": service["name"]},
            ),
            MetricDatum(
                name="HealthCheckLatency",
                value=latency,
                unit=MetricUnit.MILLISECONDS,
                dimensions={"Service": service["name"]},
            ),
        ])

        # Alert once the threshold is reached (alert_publisher exists only if alerts are enabled)
        if (
            consecutive_failures >= ctx.settings.consecutive_failure_threshold
            and ctx.alert_publisher is not None
        ):
            ctx.alert_publisher.publish(
                subject=f"Service Unhealthy: {service['name']}",
                message=(
                    f"Service {service['name']} has failed "
                    f"{consecutive_failures} consecutive health checks"
                ),
            )

        results.append({
            "service": service["name"],
            "healthy": healthy,
            "latency_ms": latency,
        })

    return LambdaResponse.success(data={
        "summary": {"healthy": sum(1 for r in results if r["healthy"]), "total": len(results)},
        "results": results,
    }).to_dict()


def check_service(service: dict, timeout: int, namespace: str) -> tuple[bool, float]:
    """Check whether a service is healthy. Returns (healthy, latency_ms)."""
    url = f"http://{service['name']}.{namespace}:{service['port']}{service['path']}"
    start = time.perf_counter()
    try:
        response = httpx.get(url, timeout=timeout)
        return response.status_code == 200, (time.perf_counter() - start) * 1000
    except Exception:
        # Any error (timeout, DNS, refused connection) counts as unhealthy
        return False, (time.perf_counter() - start) * 1000
```
### Service with Dependency Injection

Complete example of a FastAPI service using dependency injection.

```python
from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import Depends, FastAPI
from pydantic import BaseModel

from tradai.common import LoggerMixin
from tradai.common.clients import DataCollectionClient, DataCollectionClientConfig


# Settings
class Settings(BaseModel):
    data_collection_url: str = "http://localhost:8002"
    debug: bool = False


# Request body
class BacktestRequest(BaseModel):
    strategy: str
    symbols: list[str]
    timeframe: str = "1h"


# Service class
class BacktestService(LoggerMixin):
    """Business logic for backtesting."""

    def __init__(self, data_client: DataCollectionClient):
        self._data_client = data_client

    async def run_backtest(self, strategy: str, symbols: list[str], timeframe: str) -> dict:
        """Execute a backtest after ensuring data freshness."""
        self.logger.info(f"Running backtest for {strategy}", extra={"symbols": symbols})

        # Check data freshness (a blocking client call, as in the examples above)
        freshness = self._data_client.check_freshness(symbols=symbols)
        if not freshness.all_fresh:
            self.logger.warning(f"Stale data detected: {freshness.stale_symbols}")
            # Trigger a sync for the stale symbols only
            self._data_client.sync_data(
                symbols=[s.symbol for s in freshness.stale_symbols],
                start_date="2024-01-01",
                end_date="2024-12-31",
            )

        # Run backtest logic here...
        return {"strategy": strategy, "status": "completed"}


# Dependency providers
def get_settings() -> Settings:
    return Settings()


def get_data_client(settings: Annotated[Settings, Depends(get_settings)]) -> DataCollectionClient:
    config = DataCollectionClientConfig(base_url=settings.data_collection_url)
    return DataCollectionClient(config)


def get_backtest_service(
    data_client: Annotated[DataCollectionClient, Depends(get_data_client)],
) -> BacktestService:
    return BacktestService(data_client=data_client)


# FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup work goes before the yield, shutdown work after it
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/api/v1/backtests")
async def create_backtest(
    request: BacktestRequest,
    service: Annotated[BacktestService, Depends(get_backtest_service)],
) -> dict:
    """Submit a new backtest."""
    return await service.run_backtest(request.strategy, request.symbols, request.timeframe)
```
### Circuit Breaker Pattern

Using a circuit breaker for resilient external calls.

```python
from tradai.common import CircuitBreaker, CircuitBreakerConfig, CircuitBreakerError
from tradai.common.clients import DataCollectionClient, DataCollectionClientConfig

# Configure the circuit breaker
circuit_config = CircuitBreakerConfig(
    failure_threshold=5,    # Open after 5 failures
    recovery_timeout=30.0,  # Try again after 30 seconds
    half_open_requests=2,   # Allow 2 test requests when half-open
)
circuit_breaker = CircuitBreaker(config=circuit_config)


async def fetch_with_circuit_breaker(symbols: list[str]) -> dict:
    """Fetch data with circuit breaker protection."""
    # Skip the call entirely while the circuit is open
    if not circuit_breaker.can_execute():
        return {"source": "cache", "data": get_cached_data(symbols)}

    try:
        client = DataCollectionClient(
            DataCollectionClientConfig(base_url="http://data-collection:8002")
        )
        result = client.check_freshness(symbols=symbols)
        circuit_breaker.record_success()
        return {"source": "live", "data": result}
    except Exception as e:
        circuit_breaker.record_failure()
        if circuit_breaker.is_open:
            # This failure opened (or reopened) the circuit; log it before falling back
            print(f"Circuit breaker opened after failure: {e}")
        # Fall back to cached data
        return {"source": "cache", "data": get_cached_data(symbols)}


def get_cached_data(symbols: list[str]) -> dict:
    """Return cached data as a fallback."""
    return {"symbols": symbols, "cached": True}
```
## See Also
Related SDKs:
- tradai-data - Data layer (repositories, adapters)
- tradai-strategy - Strategy framework (TradAIStrategy, preflight)
Architecture:
- Architecture Overview - System diagrams
- DESIGN.md - Design decisions and patterns
Lambdas:
- Lambda Functions - Serverless function reference
- backtest-consumer - Uses ECS utilities
- health-check - Uses HTTP client utilities
Services:
- Backend Service - Uses BaseService, LoggerMixin
- Strategy Service - Uses MLflowAdapter
- Data Collection - Uses CircuitBreaker
CLI:
- CLI Reference - Uses common entities and settings