"""Circuit Breaker — prevents cascading failures on external integrations.

States: CLOSED (normal) -> OPEN (failing) -> HALF_OPEN (probing).
When open, calls fail fast without hitting the external service.
"""

from __future__ import annotations

import logging
import time
from enum import Enum
from typing import Any, Callable, Coroutine, Optional

logger = logging.getLogger("dealix.circuit_breaker")


class CircuitState(str, Enum):
    """Lifecycle states of a :class:`CircuitBreaker`."""

    CLOSED = "closed"
    OPEN = "open"
    HALF_OPEN = "half_open"


class CircuitOpenError(Exception):
    """Raised by :meth:`CircuitBreaker.call` when a call is rejected.

    A call is rejected either because the circuit is OPEN or because the
    circuit is HALF_OPEN and all probe slots are already taken.
    """


class CircuitBreaker:
    """In-memory circuit breaker for external service calls.

    All state lives on the instance; no locking is performed.

    Args:
        name: Label used in log messages, error messages and ``to_dict``.
        failure_threshold: Number of recorded failures (without an
            intervening success) at which the circuit opens.
        recovery_timeout: Seconds (``time.monotonic`` based) that must pass
            after the last failure before an OPEN circuit becomes HALF_OPEN.
        half_open_max_calls: Maximum number of probe calls admitted while
            the circuit is HALF_OPEN.
    """

    def __init__(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        half_open_max_calls: int = 1,
    ):
        self.name = name
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.half_open_max_calls = half_open_max_calls
        self._state = CircuitState.CLOSED
        self._failure_count = 0
        self._last_failure_time: float = 0.0
        self._half_open_calls = 0

    @property
    def state(self) -> CircuitState:
        """Return the current state, promoting OPEN to HALF_OPEN when due.

        Reading this property has a side effect: once ``recovery_timeout``
        seconds have elapsed since the last failure, an OPEN circuit is
        switched to HALF_OPEN and its probe counter is reset.
        """
        if self._state == CircuitState.OPEN:
            elapsed = time.monotonic() - self._last_failure_time
            if elapsed >= self.recovery_timeout:
                self._state = CircuitState.HALF_OPEN
                self._half_open_calls = 0
                logger.info("CircuitBreaker[%s] OPEN -> HALF_OPEN", self.name)
        return self._state

    def record_success(self) -> None:
        """Record a successful call.

        A success while HALF_OPEN closes the circuit; a success while CLOSED
        clears the failure counter. A success while OPEN is ignored.
        """
        if self._state == CircuitState.HALF_OPEN:
            self._state = CircuitState.CLOSED
            self._failure_count = 0
            logger.info("CircuitBreaker[%s] HALF_OPEN -> CLOSED", self.name)
        elif self._state == CircuitState.CLOSED:
            self._failure_count = 0

    def record_failure(self) -> None:
        """Record a failed call and open the circuit at the threshold.

        The failure timestamp is refreshed on every failure, so the recovery
        timeout is always measured from the most recent one. The failure
        counter is only cleared by ``record_success``, so a failed probe in
        HALF_OPEN is still at or above the threshold and re-opens the
        circuit.
        """
        self._failure_count += 1
        self._last_failure_time = time.monotonic()
        if self._failure_count >= self.failure_threshold:
            self._state = CircuitState.OPEN
            logger.warning(
                "CircuitBreaker[%s] -> OPEN (failures=%d)",
                self.name,
                self._failure_count,
            )

    async def call(
        self,
        func: Callable[..., Coroutine[Any, Any, Any]],
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Await ``func(*args, **kwargs)`` under circuit-breaker protection.

        Args:
            func: Callable returning an awaitable; it is invoked with the
                given positional and keyword arguments.

        Returns:
            Whatever the awaited ``func`` call returns.

        Raises:
            CircuitOpenError: The circuit is OPEN, or it is HALF_OPEN and
                ``half_open_max_calls`` probes have already been admitted.
            Exception: Any exception raised by ``func`` is recorded as a
                failure and re-raised unchanged.
        """
        current_state = self.state
        if current_state == CircuitState.OPEN:
            raise CircuitOpenError(
                f"Circuit {self.name} is OPEN — failing fast"
            )

        holds_probe_slot = False
        if current_state == CircuitState.HALF_OPEN:
            if self._half_open_calls >= self.half_open_max_calls:
                raise CircuitOpenError(
                    f"Circuit {self.name} HALF_OPEN — max probe calls reached"
                )
            self._half_open_calls += 1
            holds_probe_slot = True

        try:
            result = await func(*args, **kwargs)
        except Exception:
            self.record_failure()
            raise
        except BaseException:
            # Cancellation (asyncio.CancelledError) and similar exits carry no
            # verdict about the service. Hand the probe slot back; otherwise
            # the breaker would sit in HALF_OPEN with every slot consumed and
            # reject all future calls, since only OPEN -> HALF_OPEN resets
            # the probe counter.
            if holds_probe_slot and self._state == CircuitState.HALF_OPEN:
                self._half_open_calls = max(0, self._half_open_calls - 1)
            raise

        self.record_success()
        return result

    def to_dict(self) -> dict[str, Any]:
        """Return a snapshot of the breaker's configuration and state.

        Uses the ``state`` property, so an overdue OPEN circuit is reported
        (and switched) as HALF_OPEN.
        """
        return {
            "name": self.name,
            "state": self.state.value,
            "failure_count": self._failure_count,
            "failure_threshold": self.failure_threshold,
            "recovery_timeout": self.recovery_timeout,
        }


class CircuitBreakerRegistry:
    """Registry of named circuit breakers for external services."""

    def __init__(self) -> None:
        self._breakers: dict[str, CircuitBreaker] = {}

    def get(
        self,
        name: str,
        failure_threshold: int = 5,
        recovery_timeout: float = 60.0,
        half_open_max_calls: int = 1,
    ) -> CircuitBreaker:
        """Return the breaker registered under ``name``, creating it if new.

        The configuration arguments are used only when the breaker is first
        created; for an already registered name they are ignored and the
        existing instance is returned as-is.
        """
        breaker = self._breakers.get(name)
        if breaker is None:
            breaker = CircuitBreaker(
                name=name,
                failure_threshold=failure_threshold,
                recovery_timeout=recovery_timeout,
                half_open_max_calls=half_open_max_calls,
            )
            self._breakers[name] = breaker
        return breaker

    def all_states(self) -> dict[str, dict[str, Any]]:
        """Return ``to_dict()`` snapshots of every breaker, keyed by name."""
        return {name: cb.to_dict() for name, cb in self._breakers.items()}


# Shared module-level registry instance.
registry = CircuitBreakerRegistry()