Skip to content

tradai-common

Shared utilities, base classes, and AWS integrations for the TradAI platform.

from tradai.common import BaseService, LoggerMixin, BacktestConfig

Base Classes

BaseService

Base class for all TradAI microservices with lifecycle management.

from tradai.common import BaseService

class MyService(BaseService):
    """Example service that plugs into the BaseService lifecycle hooks."""

    async def start(self) -> None:
        """Run the base start-up first, then acquire this service's resources."""
        await super().start()
        # Acquire service-specific resources here (clients, connections, ...)

    async def stop(self) -> None:
        """Release this service's resources, then run the base shutdown."""
        # Release service-specific resources here, before the base teardown
        await super().stop()

Methods:

| Method | Description |
| --- | --- |
| `start()` | Initialize the service; called on startup |
| `stop()` | Clean up resources; called on shutdown |
| `health_check()` | Return the health status |

Settings

Pydantic settings base class with environment variable support.

from pydantic_settings import BaseSettings, SettingsConfigDict

class MySettings(BaseSettings):
    """Service settings read from ``MY_SERVICE_*`` environment variables or ``.env``."""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_prefix="MY_SERVICE_",
        extra="ignore",  # unknown keys in the environment/.env are not an error
    )

    database_url: str  # required; no default
    debug: bool = False

Features:

- Automatic environment variable loading
- Type validation
- Default values
- Nested configuration support


LoggerMixin

Mixin class providing consistent logging across all components.

from tradai.common import LoggerMixin

class MyClass(LoggerMixin):
    """Example component that logs through the mixin-provided ``self.logger``."""

    def do_something(self):
        """Emit one operational message and one diagnostic message."""
        self.logger.info("Doing something")
        # Structured context travels in ``extra`` rather than in the message text
        self.logger.debug("Debug details", extra={"key": "value"})

Log levels:

- `DEBUG` – detailed diagnostic information
- `INFO` – key operational events
- `WARNING` – recoverable issues
- `ERROR` – failures requiring attention


Entities

BacktestConfig

Configuration for backtest execution.

from tradai.common import BacktestConfig

# NOTE: DateRange must also be imported (its import path is not shown here),
# and start_date / end_date are assumed to be defined by the caller.
config = BacktestConfig(
    strategy="TrendFollowingStrategy",
    pairs=["BTC/USDT:USDT", "ETH/USDT:USDT"],
    timeframe="1h",
    date_range=DateRange(start=start_date, end=end_date),
    stake_amount=100.0,  # amount per trade
    dry_run_wallet=10000.0  # starting capital
)

Fields:

| Field | Type | Description |
| --- | --- | --- |
| `strategy` | `str` | Strategy class name |
| `pairs` | `list[str]` | Trading pairs |
| `timeframe` | `str` | Candle timeframe |
| `date_range` | `DateRange` | Backtest period |
| `stake_amount` | `float` | Amount per trade |
| `dry_run_wallet` | `float` | Starting capital |

BacktestResult

Results from a completed backtest.

from tradai.common import BacktestResult

# `result` is assumed to be a BacktestResult returned by a completed
# backtest (how it is obtained is not shown here).
# Access result fields
print(f"Sharpe Ratio: {result.sharpe_ratio}")
print(f"Total Profit: {result.total_profit_pct}%")
print(f"Max Drawdown: {result.max_drawdown_pct}%")

Fields:

| Field | Type | Description |
| --- | --- | --- |
| `total_trades` | `int` | Number of trades |
| `winning_trades` | `int` | Profitable trades |
| `losing_trades` | `int` | Unprofitable trades |
| `total_profit` | `float` | Absolute profit |
| `total_profit_pct` | `float` | Profit percentage |
| `sharpe_ratio` | `float` | Risk-adjusted return |
| `sortino_ratio` | `float` | Downside-risk-adjusted return |
| `profit_factor` | `float` | Gross profit / gross loss |
| `win_rate` | `float` | Win percentage |
| `max_drawdown_pct` | `float` | Maximum drawdown (percentage) |

JobStatus

Enumeration of backtest job states.

from tradai.common import JobStatus

status = JobStatus.PENDING  # Job queued
status = JobStatus.RUNNING  # Job executing
status = JobStatus.COMPLETED  # Job finished successfully
status = JobStatus.FAILED  # Job failed with error
status = JobStatus.CANCELLED  # Job cancelled by user

Exceptions

Exception Hierarchy

TradAIError (base)
├── ValidationError      # Input validation failures
├── ConfigurationError   # Configuration issues
├── NotFoundError        # Resource not found
├── DataFetchError       # Data retrieval failures
├── DataNotFoundError    # Missing data
├── StorageError         # Storage operations failed
├── BacktestError        # Backtest execution failed
├── TradingError         # Trading operation failed
├── AuthenticationError  # Auth failures
└── CircuitOpenError     # Circuit breaker triggered

Usage:

from tradai.common import ValidationError, DataFetchError

def validate_config(config):
    """Reject a backtest config that lists no trading pairs.

    Raises:
        ValidationError: if ``config.pairs`` is empty (or otherwise falsy, e.g. None).
    """
    has_pairs = bool(config.pairs)
    if has_pairs:
        return
    raise ValidationError("At least one pair is required")

async def fetch_data(symbol):
    """Fetch OHLCV candles for *symbol*, translating exchange errors.

    Raises:
        DataFetchError: wrapping the underlying ExchangeError.
    """
    try:
        candles = await exchange.fetch_ohlcv(symbol)
    except ExchangeError as exc:
        # Chain the original exception so its traceback is preserved
        raise DataFetchError(f"Failed to fetch {symbol}") from exc
    return candles

Health Checking

HealthService

Service health management with dependency checks.

from tradai.common import HealthService, HealthChecker

# NOTE: the DatabaseHealthChecker / RedisHealthChecker imports and the `db` /
# `redis` handles are assumed to exist; their setup is not shown here.
health_service = HealthService()
health_service.add_checker("database", DatabaseHealthChecker(db))
health_service.add_checker("redis", RedisHealthChecker(redis))

# Check health (`await` must run inside a coroutine)
result = await health_service.check()
print(result.status)  # "healthy" or "unhealthy"

Built-in Checkers

| Checker | Description |
| --- | --- |
| `DatabaseHealthChecker` | PostgreSQL/MySQL connectivity |
| `RedisHealthChecker` | Redis connectivity |
| `HTTPHealthChecker` | HTTP endpoint availability |
| `DynamoDBHealthChecker` | DynamoDB table access |

HTTP Client

HttpClient

Async HTTP client with retry and circuit breaker support.

from tradai.common import HttpClient, HttpClientConfig, RetryConfig

config = HttpClientConfig(
    base_url="https://api.example.com",
    timeout=30.0,
    retry=RetryConfig(max_retries=3, backoff_factor=0.5)
)

async with HttpClient(config) as client:
    response = await client.get("/api/v1/data")
    data = response.json()

Features:

- Automatic retries with exponential backoff
- Circuit breaker pattern
- Request/response logging
- Timeout handling


MLflow Integration

MLflowAdapter

Wrapper for MLflow experiment tracking.

from tradai.common import MLflowAdapter

adapter = MLflowAdapter(tracking_uri="http://localhost:5000")

# Log experiment
with adapter.start_run(experiment_name="backtest"):
    adapter.log_params({"strategy": "TrendFollowing"})
    adapter.log_metrics({"sharpe": 1.85, "profit": 25.5})
    adapter.log_artifact("results.json")

# Query runs
runs = adapter.search_runs(
    experiment_name="backtest",
    filter_string="metrics.sharpe > 1.0"
)

ModelVersion

MLflow model version metadata.

from tradai.common import ModelVersion

version = ModelVersion(
    name="TrendFollowingStrategy",
    version="3",
    stage="Production",
    run_id="abc123"
)

Circuit Breaker

CircuitBreaker

Prevents cascading failures by failing fast.

from tradai.common import CircuitBreaker, CircuitBreakerConfig

config = CircuitBreakerConfig(
    failure_threshold=5,
    recovery_timeout=30.0,
    half_open_requests=3
)

breaker = CircuitBreaker("external-api", config)

async def call_external():
    """Invoke the external API under circuit-breaker protection."""
    async with breaker:
        # The breaker decides whether the call is attempted at all:
        # in the OPEN state it fails fast instead of calling out.
        return await external_api.call()

States:

- `CLOSED` – normal operation
- `OPEN` – failing fast without calling the dependency
- `HALF_OPEN` – testing whether the dependency has recovered


A/B Testing

ABTestManager

Manage champion/challenger experiments.

from tradai.common import ABTestManager, ABTestConfig

config = ABTestConfig(
    champion_version="2",
    challenger_version="3",
    min_trades=100,
    significance_level=0.05
)

manager = ABTestManager(config)
manager.add_trade("champion", profit_pct=1.5)
manager.add_trade("challenger", profit_pct=2.1)

result = manager.evaluate()
print(result.recommendation)  # "PROMOTE_CHALLENGER"

Model Comparison

ModelComparator

Compare model versions for promotion decisions.

from tradai.common import ModelComparator, ModelCandidate

comparator = ModelComparator()

champion = ModelCandidate(version="2", metrics={"sharpe": 1.5})
challenger = ModelCandidate(version="3", metrics={"sharpe": 1.8})

result = comparator.compare(champion, challenger)
print(result.decision)  # "PROMOTE"
print(result.confidence)  # 0.85

Drift Detection

DriftDetector

Detect feature and prediction drift.

from tradai.common import DriftDetector, DriftThresholds

detector = DriftDetector(
    thresholds=DriftThresholds(
        psi_threshold=0.2,
        ks_threshold=0.1
    )
)

result = detector.detect(
    reference_data=training_df,
    current_data=production_df
)

if result.has_drift:
    print(f"Drift detected: {result.drifted_features}")

Utilities

with_retry

Decorator for automatic retries with backoff.

from tradai.common import with_retry

@with_retry(max_attempts=3, backoff_factor=2.0)
async def fetch_data():
    """Call the API; failures are retried by ``with_retry`` (3 attempts max, backoff factor 2.0)."""
    return await api.call()

FreqtradeCLIBuilder

Build Freqtrade CLI commands programmatically.

from tradai.common import FreqtradeCLIBuilder

cmd = FreqtradeCLIBuilder()
cmd.backtesting()
cmd.strategy("MyStrategy")
cmd.timerange("20240101-20240301")
cmd.pairs(["BTC/USDT:USDT"])

args = cmd.build()  # ['backtesting', '-s', 'MyStrategy', ...]

AWS Module

The AWS module provides pre-configured, thread-safe clients for AWS services.

from tradai.common.aws import (
    DynamoDBAdapter, AsyncDynamoDBAdapter,
    SNSPublisher, MetricsPublisher,
    AsyncS3Client, S3ConfigRepository,
    get_secret,
)

DynamoDBAdapter

Thread-safe DynamoDB operations with automatic serialization.

from tradai.common.aws import DynamoDBAdapter

adapter = DynamoDBAdapter(table_name="tradai-state")

# Put item
adapter.put_item({"pk": "strategy-1", "status": "running"})

# Get item
item = adapter.get_item({"pk": "strategy-1"})

# Query with filter
items = adapter.query(
    key_condition="pk = :pk",
    expression_values={":pk": "strategy-1"}
)

SNSPublisher

Publish alerts and notifications to SNS topics.

from tradai.common.aws import SNSPublisher

publisher = SNSPublisher(topic_arn="arn:aws:sns:...")

publisher.publish(
    subject="Alert: Model Drift Detected",
    message="Model PascalStrategy shows significant drift",
    message_attributes={"severity": "HIGH", "model": "PascalStrategy"}
)

MetricsPublisher

Publish CloudWatch custom metrics.

from tradai.common.aws import MetricsPublisher, MetricDatum, MetricUnit

publisher = MetricsPublisher(namespace="TradAI/Backtests")

publisher.publish([
    MetricDatum(
        name="BacktestDuration",
        value=125.5,
        unit=MetricUnit.SECONDS,
        dimensions={"Strategy": "PascalStrategy"}
    ),
    MetricDatum(
        name="SharpeRatio",
        value=1.85,
        unit=MetricUnit.NONE,
        dimensions={"Strategy": "PascalStrategy"}
    )
])

AsyncS3Client

Async S3 operations for data storage.

from tradai.common.aws import AsyncS3Client

async with AsyncS3Client() as client:
    # Upload data
    await client.upload_to_s3(
        bucket="tradai-data",
        key="backtests/result.json",
        data=json.dumps(result)
    )

    # Download data
    content = await client.download_from_s3(
        bucket="tradai-data",
        key="backtests/result.json"
    )

DynamoDBStateRepository

Generic state repository for Lambda handlers.

from tradai.common.aws.state_repository import DynamoDBStateRepository
from tradai.common.lambda_ import HealthState

repo = DynamoDBStateRepository(
    table_name="tradai-health-state",
    key_name="service_name",
    entity_class=HealthState,
    ttl_seconds=86400  # 24 hours
)

# Get/put state
state = repo.get("backend-api")
repo.put(HealthState(service_name="backend-api", status="healthy"))

get_secret

Retrieve secrets from AWS Secrets Manager.

from tradai.common.aws import get_secret

# Returns parsed JSON or raw string
db_creds = get_secret("tradai/database/credentials")
api_key = get_secret("tradai/exchange/api-key")

Lambda Module

The Lambda module provides decorators and utilities for building Lambda handlers.

from tradai.common.lambda_ import (
    lambda_handler,
    LambdaContext,
    LambdaResponse,
    LambdaSettings,
)

@lambda_handler Decorator

Wraps Lambda handlers with automatic setup, error handling, and warm start optimization.

from tradai.common.lambda_ import (
    lambda_handler,
    LambdaContext,
    LambdaResponse,
    HealthCheckSettings,
)

@lambda_handler(HealthCheckSettings, cloudwatch_namespace_suffix="ServiceHealth")
def handler(event: dict, ctx: LambdaContext[HealthCheckSettings]) -> dict:
    """Minimal handler: the decorator builds and injects ``ctx`` before this runs."""
    settings = ctx.settings  # HealthCheckSettings instance loaded by the decorator

    # Publishers arrive pre-configured on the context
    ctx.metrics_publisher.publish([...])
    ctx.alert_publisher.publish(subject="...", message="...")

    response = LambdaResponse.success(data={"healthy": True})
    return response.to_dict()

Features:

| Feature | Description |
| --- | --- |
| Settings loading | Automatic environment-variable loading via Pydantic |
| Context creation | Pre-configured SNS and CloudWatch publishers |
| Warm-start caching | Context cached with a 55-minute TTL |
| Error handling | Consistent error response format |
| Step Functions | Optional `step_functions=True` for Step Functions-compatible errors |

LambdaContext

Dependency container injected into handlers.

# Accessed via handler's ctx parameter
ctx.settings         # LambdaSettings instance
ctx.logger           # Configured logger
ctx.metrics_publisher # MetricsPublisher for CloudWatch
ctx.alert_publisher   # SNSPublisher for alerts (if enabled)

LambdaResponse

Builder for consistent Lambda responses.

from tradai.common.lambda_ import LambdaResponse

# Success response (API Gateway format)
LambdaResponse.success(data={"result": "ok"}).to_dict()
# {"statusCode": 200, "body": {"success": true, "data": {...}}}

# For Step Functions
LambdaResponse.success(data={"result": "ok"}).to_step_functions()
# {"statusCode": 200, "body": {...}}  - SF accesses via $.Payload.body.data

# Error response
LambdaResponse.error(
    message="Validation failed",
    error_type="ValidationError"
).to_dict()

LambdaSettings Classes

Pre-built settings for common Lambda patterns:

| Class | Use case |
| --- | --- |
| `LambdaSettings` | Base class with common fields |
| `DynamoDBSettings` | Adds DynamoDB table configuration |
| `HealthCheckSettings` | Health check Lambdas |
| `DriftMonitorSettings` | Drift detection Lambdas |
| `RetrainingSchedulerSettings` | Model retraining Lambdas |
| `ModelManagementSettings` | MLflow operations |

Entity Classes

DynamoDB entities for state tracking:

from datetime import UTC, datetime

from tradai.common.lambda_ import HealthState, DriftState, HeartbeatState

# All entities implement to_dynamodb_item() and from_dynamodb_item()
state = HealthState(
    service_name="backend-api",
    status="healthy",
    consecutive_failures=0,
    last_check=datetime.now(UTC).isoformat()  # timezone-aware ISO-8601 timestamp
)

item = state.to_dynamodb_item()  # For DynamoDB put_item

Clients Module

Service clients for inter-service communication with circuit breaker integration.

from tradai.common.clients import (
    DataCollectionClient,
    StrategyServiceClient,
    MLflowClient,
)

DataCollectionClient

Client for the data-collection service.

from tradai.common.clients import DataCollectionClient, DataCollectionClientConfig

config = DataCollectionClientConfig(
    base_url="http://data-collection.tradai.local:8002",
    timeout=30.0
)

client = DataCollectionClient(config)

# Sync market data
result = client.sync_data(
    symbols=["BTC/USDT:USDT", "ETH/USDT:USDT"],
    start_date="2024-01-01",
    end_date="2024-01-31",
)

# Check data freshness
freshness = client.check_freshness(
    symbols=["BTC/USDT:USDT"],
    stale_threshold_hours=24
)
if not freshness.all_fresh:
    print(f"Stale symbols: {freshness.stale_symbols}")

# List available symbols
symbols = client.list_symbols()

StrategyServiceClient

Client for the strategy-service.

from tradai.common.clients import StrategyServiceClient, StrategyClientConfig

config = StrategyClientConfig(
    base_url="http://strategy-service.tradai.local:8003"
)

client = StrategyServiceClient(config)

# List strategies
strategies = client.list_strategies()

# Get strategy details
strategy = client.get_strategy("PascalStrategy")

# Submit backtest
job = client.submit_backtest(
    strategy="PascalStrategy",
    pairs=["BTC/USDT:USDT"],
    timeframe="1h",
    start_date="2024-01-01",
    end_date="2024-01-31"
)

Response Types

from tradai.common.clients import (
    SyncResponse,
    FreshnessResponse,
    SymbolFreshness,
    StrategyListResponse,
)

# Type-safe response handling
freshness: FreshnessResponse = client.check_freshness(...)
for symbol in freshness.stale_symbols:
    print(f"{symbol.symbol}: last_updated={symbol.last_updated}")

Validation Module

Validation entities for dry-run and go-live deployments.

from tradai.common.validation import (
    DryRunValidationReport,
    GoLiveValidationReport,
    HeartbeatMetrics,
    CandleMetrics,
)

DryRunValidationReport

Report for validating strategy behavior in dry-run mode.

from tradai.common.validation import (
    DryRunValidationReport,
    HeartbeatMetrics,
    CandleMetrics,
    OrderMetrics,
)

heartbeat = HeartbeatMetrics(
    total_expected=1440,  # 24h at 1-min intervals
    total_received=1435,
    max_gap_minutes=5,
    uptime_pct=99.6
)

candle = CandleMetrics(
    total_candles=1440,
    missing_candles=2,
    delayed_candles=5,
    completeness_pct=99.5
)

report = DryRunValidationReport(
    strategy_id="pascal-v2",
    heartbeat_metrics=heartbeat,
    candle_metrics=candle,
    checks=[...],
)

if report.overall_passed:
    print("Ready for go-live!")
else:
    print(f"Failed checks: {report.failed_checks}")

GoLiveValidationReport

Report for validating infrastructure readiness.

from tradai.common.validation import (
    GoLiveValidationReport,
    GoLiveCheckResult,
    GoLiveCheckStatus,
    AlarmCheckResult,
    LambdaCheckResult,
)

alarm_check = AlarmCheckResult(
    alarm_name="high-drawdown-alarm",
    exists=True,
    enabled=True,
    status=GoLiveCheckStatus.PASSED
)

lambda_check = LambdaCheckResult(
    function_name="model-rollback",
    exists=True,
    memory_mb=256,
    timeout_seconds=120,
    status=GoLiveCheckStatus.PASSED
)

report = GoLiveValidationReport(
    strategy_id="pascal-v2",
    environment="prod",
    alarm_checks=[alarm_check],
    lambda_checks=[lambda_check],
)

print(f"Ready: {report.all_passed}")

Check Severities

| Severity | Impact |
| --- | --- |
| `ERROR` | Blocks deployment |
| `WARNING` | Review recommended |
| `INFO` | Informational only |

Complete Examples

Lambda Handler Pattern

Complete example of a Lambda handler with settings, state management, and metrics.

from datetime import datetime, UTC
from typing import Any

from tradai.common.lambda_ import (
    lambda_handler,
    LambdaContext,
    LambdaResponse,
    LambdaSettings,
    HealthState,
)
from tradai.common.aws import DynamoDBAdapter, MetricDatum, MetricUnit


class MyLambdaSettings(LambdaSettings):
    """Settings for my Lambda function.

    Values are loaded from environment variables by the ``lambda_handler``
    decorator; fields with defaults below are optional.
    """
    service_discovery_namespace: str = "tradai.local"
    health_timeout_seconds: int = 30
    consecutive_failure_threshold: int = 3
    # The handler reads ``dynamodb_table_name``. DynamoDB table config is listed
    # under DynamoDBSettings rather than the LambdaSettings base, so declare it
    # explicitly here (required, no default).
    # NOTE(review): drop this if LambdaSettings already defines the field.
    dynamodb_table_name: str


@lambda_handler(MyLambdaSettings, cloudwatch_namespace_suffix="MyLambda")
def handler(event: dict, ctx: LambdaContext[MyLambdaSettings]) -> dict:
    """Process health checks for the services listed in the event.

    ``event["services"]`` is a list of dicts with ``name``, ``port`` and
    ``path`` keys (the fields ``check_service`` reads). For each service the
    handler does the following:

    - probes the health endpoint;
    - persists the consecutive-failure streak in DynamoDB;
    - publishes CloudWatch metrics;
    - sends an SNS alert once the failure threshold is reached, if alerting
      is enabled.

    Returns:
        ``LambdaResponse.success(...).to_dict()`` with a healthy/total summary
        and per-service results.
    """
    services = event.get("services", [])
    results = []
    settings = ctx.settings

    # Initialize state repository
    state_repo = DynamoDBAdapter(table_name=settings.dynamodb_table_name)

    for service in services:
        name = service["name"]

        # Check service health inside the configured service-discovery namespace
        healthy, latency = check_service(
            service,
            settings.health_timeout_seconds,
            namespace=settings.service_discovery_namespace,
        )

        # Load previous state so the failure streak carries across invocations
        state_key = {"pk": f"HEALTH#{name}"}
        existing = state_repo.get_item(state_key)

        if healthy:
            consecutive_failures = 0
        else:
            consecutive_failures = (existing.get("consecutive_failures", 0) + 1) if existing else 1

        # Save updated state
        state_repo.put_item({
            **state_key,
            "status": "healthy" if healthy else "unhealthy",
            "consecutive_failures": consecutive_failures,
            "last_check": datetime.now(UTC).isoformat(),
            "latency_ms": latency,
        })

        # Publish metrics
        ctx.metrics_publisher.publish([
            MetricDatum(
                name="ServiceHealthy",
                value=1 if healthy else 0,
                unit=MetricUnit.COUNT,
                dimensions={"Service": name}
            ),
            MetricDatum(
                name="HealthCheckLatency",
                value=latency,
                unit=MetricUnit.MILLISECONDS,
                dimensions={"Service": name}
            ),
        ])

        # Alert if threshold exceeded. NOTE: this fires on every run while the
        # streak stays at or above the threshold, not only when it is crossed.
        if consecutive_failures >= settings.consecutive_failure_threshold:
            # alert_publisher is only set when alerting is enabled (presumably
            # None otherwise); guard instead of crashing the whole run.
            if ctx.alert_publisher is None:
                ctx.logger.warning(
                    f"Alerting disabled; {name} has failed {consecutive_failures} consecutive health checks"
                )
            else:
                ctx.alert_publisher.publish(
                    subject=f"Service Unhealthy: {name}",
                    message=f"Service {name} has failed {consecutive_failures} consecutive health checks",
                )

        results.append({
            "service": name,
            "healthy": healthy,
            "latency_ms": latency,
        })

    return LambdaResponse.success(data={
        "summary": {"healthy": sum(1 for r in results if r["healthy"]), "total": len(results)},
        "results": results,
    }).to_dict()


def check_service(service: dict, timeout: int, namespace: str = "tradai.local") -> tuple[bool, float]:
    """Probe a service's HTTP health endpoint.

    Args:
        service: Dict with ``name``, ``port`` and ``path`` keys. The probed URL
            is ``http://{name}.{namespace}:{port}{path}``.
        timeout: Request timeout in seconds, passed to ``httpx.get``.
        namespace: Service-discovery DNS namespace in which ``name`` is resolved.

    Returns:
        ``(healthy, latency_ms)``. ``healthy`` is True only for an HTTP 200
        response. ``latency_ms`` is the elapsed time, including for failed
        attempts.

    Raises:
        KeyError: if ``service`` lacks one of the keys above.
    """
    import time

    import httpx

    url = f"http://{service['name']}.{namespace}:{service['port']}{service['path']}"
    # perf_counter is monotonic: unlike time.time() it cannot jump (or go
    # negative) when the system clock is adjusted mid-request.
    start = time.perf_counter()
    try:
        response = httpx.get(url, timeout=timeout)
        healthy = response.status_code == 200
    except Exception:
        # Best-effort probe: any failure (timeout, refused connection, DNS
        # error, ...) counts as unhealthy rather than aborting the whole run.
        healthy = False
    latency_ms = (time.perf_counter() - start) * 1000
    return healthy, latency_ms

Service with Dependency Injection

Complete example of a FastAPI service using DI patterns.

from contextlib import asynccontextmanager
from typing import Annotated

from fastapi import FastAPI, Depends, HTTPException
from pydantic import BaseModel
from pydantic_settings import BaseSettings

from tradai.common import LoggerMixin
from tradai.common.clients import DataCollectionClient, DataCollectionClientConfig


# Settings
class Settings(BaseSettings):
    """Application settings.

    Uses ``BaseSettings`` (not ``BaseModel``) so each field can be overridden
    by an environment variable of the same name, e.g. ``DATA_COLLECTION_URL``.
    The defaults suit local development.
    """

    data_collection_url: str = "http://localhost:8002"
    debug: bool = False


# Service class
class BacktestService(LoggerMixin):
    """Business logic for backtesting.

    Args:
        data_client: Client for the data-collection service. Its methods are
            synchronous, so async code calls them in a worker thread.
        sync_lookback_days: Days of history (ending today) to re-sync for
            symbols reported as stale. Defaults to one year.
    """

    def __init__(self, data_client: DataCollectionClient, sync_lookback_days: int = 365):
        self._data_client = data_client
        self._sync_lookback_days = sync_lookback_days

    async def run_backtest(self, strategy: str, symbols: list[str], timeframe: str) -> dict:
        """Execute a backtest after ensuring data freshness.

        Returns:
            A dict with the ``strategy`` name and a ``status`` string.
        """
        import asyncio
        from datetime import UTC, datetime, timedelta

        self.logger.info(f"Running backtest for {strategy}", extra={"symbols": symbols})

        # Check data freshness. The client is synchronous, so run it in a
        # thread to avoid blocking the event loop (and other requests).
        freshness = await asyncio.to_thread(self._data_client.check_freshness, symbols=symbols)
        if not freshness.all_fresh:
            self.logger.warning(f"Stale data detected: {freshness.stale_symbols}")
            # Re-sync a window that ends today. A fixed calendar range would
            # stop curing staleness as soon as "now" moved past its end date.
            end = datetime.now(UTC).date()
            start = end - timedelta(days=self._sync_lookback_days)
            await asyncio.to_thread(
                self._data_client.sync_data,
                symbols=[s.symbol for s in freshness.stale_symbols],
                start_date=start.isoformat(),
                end_date=end.isoformat(),
            )

        # Run backtest logic here (using `timeframe`)...
        return {"strategy": strategy, "status": "completed"}


# Dependency injection
def get_settings() -> Settings:
    """FastAPI dependency that supplies the application settings."""
    settings = Settings()
    return settings


def get_data_client(settings: Annotated[Settings, Depends(get_settings)]) -> DataCollectionClient:
    """FastAPI dependency: build a data-collection client from the current settings."""
    return DataCollectionClient(
        DataCollectionClientConfig(base_url=settings.data_collection_url)
    )


def get_backtest_service(
    data_client: Annotated[DataCollectionClient, Depends(get_data_client)]
) -> BacktestService:
    """FastAPI dependency: wrap the injected client in a BacktestService."""
    service = BacktestService(data_client=data_client)
    return service


# FastAPI app
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook; currently does no startup or shutdown work.

    Startup code would go before ``yield`` and shutdown code after it.
    """
    yield


app = FastAPI(lifespan=lifespan)


@app.post("/api/v1/backtests")
async def create_backtest(
    strategy: str,
    symbols: list[str],
    timeframe: str = "1h",
    *,
    service: Annotated[BacktestService, Depends(get_backtest_service)],
):
    """Submit a new backtest.

    ``service`` is injected by FastAPI via ``Depends``. It is keyword-only so
    it can follow the defaulted ``timeframe`` without a placeholder ``= None``,
    which would contradict its type and turn DI mistakes into AttributeErrors.
    """
    return await service.run_backtest(strategy, symbols, timeframe)

Circuit Breaker Pattern

Using circuit breaker for resilient external calls.

from tradai.common import CircuitBreaker, CircuitBreakerConfig, CircuitBreakerError
from tradai.common.clients import DataCollectionClient, DataCollectionClientConfig


# Configure circuit breaker
circuit_config = CircuitBreakerConfig(
    failure_threshold=5,      # Open after 5 failures
    recovery_timeout=30.0,    # Try again after 30 seconds
    half_open_requests=2,     # Allow 2 test requests when half-open
)

# Name the breaker after the dependency it protects, matching the
# CircuitBreaker(name, config) form used by the API reference example.
circuit_breaker = CircuitBreaker("data-collection", config=circuit_config)


async def fetch_with_circuit_breaker(symbols: list[str]) -> dict:
    """Fetch data-freshness info with circuit breaker protection.

    Returns:
        ``{"source": "live", "data": ...}`` on success. If the breaker refuses
        the call, or the call fails, it returns
        ``{"source": "cache", "data": get_cached_data(symbols)}`` instead.
    """
    import asyncio
    import logging

    # Check if circuit is open
    if not circuit_breaker.can_execute():
        # Return cached/default data when circuit is open
        return {"source": "cache", "data": get_cached_data(symbols)}

    try:
        client = DataCollectionClient(
            DataCollectionClientConfig(base_url="http://data-collection:8002")
        )
        # The client is synchronous; run it in a worker thread so this
        # coroutine does not block the event loop while waiting on HTTP.
        result = await asyncio.to_thread(client.check_freshness, symbols=symbols)
    except Exception as e:
        # Any failure of the protected call counts against the breaker.
        circuit_breaker.record_failure()

        if circuit_breaker.is_open:
            # can_execute() passed above, so an open breaker here means this
            # failure is the one that tripped it.
            logging.getLogger(__name__).warning("Circuit breaker opened after failure: %s", e)

        # Return cached data
        return {"source": "cache", "data": get_cached_data(symbols)}

    # Recorded outside the try so a problem in the breaker itself is not
    # miscounted as a failure of the protected call.
    circuit_breaker.record_success()
    return {"source": "live", "data": result}


def get_cached_data(symbols: list[str]) -> dict:
    """Placeholder fallback: echo *symbols* back, tagged with ``cached=True``."""
    fallback = {"symbols": symbols}
    fallback["cached"] = True
    return fallback

See Also

Related SDKs:

Architecture:

Lambdas:

Services:

CLI: