1import asyncio
2from typing import TypeVar
3
# Generic type variable available for annotations in this module.
T = TypeVar("T")
5
6
class AIError(Exception):
    """Root of the exception hierarchy for AI-related failures.

    Catching this type also catches every more specific error that
    derives from it.
    """
10
class RateLimitError(AIError):
    """Signals that a call was rejected because of API rate limiting.

    Attributes:
        retry_after: How long to wait, in seconds, before trying again.
    """

    def __init__(self, retry_after: int = 60):
        message = f"Rate limited. Retry after {retry_after}s"
        super().__init__(message)
        self.retry_after = retry_after
16
class APIError(AIError):
    """Signals that the API answered with an error.

    Attributes:
        status_code: Status code reported alongside the error.
    """

    def __init__(self, status_code: int, message: str):
        description = f"API Error {status_code}: {message}"
        super().__init__(description)
        self.status_code = status_code
22
23
async def with_retry(
    func,
    max_retries: int = 3,
    base_delay: float = 1.0
):
    """Await ``func()`` and retry it when it fails with a retryable error.

    Retry policy per failure type:

    * ``RateLimitError`` -- wait ``retry_after`` seconds, then retry.
    * ``APIError`` with ``status_code >= 500`` -- exponential backoff
      (``base_delay * 2 ** attempt``), then retry.
    * ``APIError`` with ``status_code < 500`` -- not retried; re-raised
      immediately.
    * Any other ``Exception`` -- exponential backoff, then retry.

    No wait (and no "retrying" message) happens after the final attempt:
    once the attempts are used up the last error is raised right away.

    Args:
        func: Zero-argument callable returning an awaitable.
        max_retries: Total number of attempts to make.
        base_delay: Backoff delay in seconds before the second attempt;
            doubled for each later attempt.

    Returns:
        Whatever the first successful ``await func()`` returns.

    Raises:
        APIError: Immediately, for a status code below 500.
        Exception: The error from the final attempt once all attempts have
            failed, or a generic ``Exception("Max retries exceeded")`` when
            ``max_retries`` is less than 1 and ``func`` was never called.
    """
    last_exception = None

    for attempt in range(max_retries):
        try:
            return await func()
        except RateLimitError as e:
            last_exception = e
            delay = e.retry_after
            notice = f" Rate limited, waiting {delay}s..."
        except APIError as e:
            if e.status_code < 500:
                # Client-side errors will not fix themselves; fail fast.
                raise
            last_exception = e
            delay = base_delay * (2 ** attempt)
            notice = f" Server error, retrying in {delay}s..."
        except Exception as e:
            last_exception = e
            delay = base_delay * (2 ** attempt)
            notice = f" Error: {e}, retrying in {delay}s..."

        # Nothing left to retry: skip the pointless wait and the misleading
        # "retrying" message, and surface the failure right away.
        if attempt == max_retries - 1:
            break

        print(notice)
        await asyncio.sleep(delay)

    if last_exception is not None:
        raise last_exception
    raise Exception("Max retries exceeded")
54
55
# Running tally of unreliable_api() invocations.
call_count = 0


async def unreliable_api():
    """Stand-in for a flaky endpoint: the first two calls fail, later ones succeed.

    Each call increments the module-level ``call_count``.

    Returns:
        A success payload dict once ``call_count`` exceeds 2.

    Raises:
        APIError: With status 503 while ``call_count`` is 2 or less.
    """
    global call_count
    call_count += 1

    succeeded = call_count > 2
    if succeeded:
        return {"status": "success", "data": "Hello!"}

    raise APIError(503, "Service temporarily unavailable")
67
68
async def main():
    """Demo driver: reset the call counter, then call the flaky API via with_retry.

    Prints the outcome either way; a failure is reported, not propagated.
    """
    global call_count
    call_count = 0  # start from zero so the simulated failures replay

    print("Calling unreliable API with retry:")
    try:
        response = await with_retry(unreliable_api)
        print(f"✓ Success: {response}")
    except Exception as err:
        # Whatever survives the retries is shown to the user instead of raised.
        print(f"✗ Failed: {err}")
79
# Run the demo only when this file is executed as a script, so that importing
# the module does not start an event loop or print anything.
if __name__ == "__main__":
    asyncio.run(main())