"""
Authentication module for handling API authentication.
"""
import base64
import hashlib
import logging
import os
import random
import secrets
import string
import time
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlencode

import httpx
from pydantic import BaseModel, ConfigDict
# Logger for this module, named after the module itself (``__name__``).
logger = logging.getLogger(__name__)
class AuthConfig(BaseModel):
    """Common base for every authentication configuration model.

    ``type`` names the authentication scheme; each subclass supplies its own
    default value. Extra keys are accepted and kept on the instance because the
    model is configured with ``extra="allow"``.
    """

    model_config = ConfigDict(extra="allow")

    type: str
class ApiKeyAuth(AuthConfig):
    """API Key authentication configuration.

    This model only stores the settings; how they are applied to outgoing
    requests is not shown in this module section (NOTE(review): confirm with the
    request-building code).
    """

    type: str = "api_key"  # identifies this authentication scheme
    key: str  # the API key value (secret)
    header_name: str = "X-API-Key"  # header name, presumably used when in_header is True
    in_header: bool = True  # presumably: send the key in ``header_name`` — confirm
    in_query: bool = False  # presumably: send the key as a query parameter — confirm
    query_param: Optional[str] = None  # query parameter name; no default is defined here
class BearerAuth(AuthConfig):
    """Bearer token authentication configuration.

    Holds a pre-issued token; presumably sent as ``Authorization: Bearer <token>``
    by the request layer (not shown here — confirm).
    """

    type: str = "bearer"  # identifies this authentication scheme
    token: str  # the bearer token value (secret)
class BasicAuth(AuthConfig):
    """Basic authentication configuration.

    Stores a username/password pair; presumably combined into an HTTP Basic
    ``Authorization`` header by the request layer (not shown here — confirm).
    """

    type: str = "basic"  # identifies this authentication scheme
    username: str
    password: str  # secret; avoid logging instances of this model
class OAuth2ClientCredentials(AuthConfig):
    """OAuth2 client credentials authentication configuration.

    ``token`` and ``expires_at`` start empty and are filled in by
    ``AuthManager.refresh_oauth2_token``, which also persists them to secure
    storage under ``storage_key`` when both are available.
    """

    type: str = "oauth2_client_credentials"
    client_id: str
    client_secret: str  # sent in the token request body
    token_url: str  # endpoint that issues access tokens
    scope: Optional[str] = None  # sent as the ``scope`` form field when set
    token: Optional[str] = None  # current access token
    expires_at: Optional[int] = None  # absolute expiry as a Unix timestamp (seconds)
    # Key to store credentials in secure storage. Declared explicitly, like on the
    # PKCE and device-flow configs, instead of relying on extra="allow"; the
    # refresh logic already looks this attribute up.
    storage_key: Optional[str] = None
class OAuth2PKCE(AuthConfig):
    """OAuth2 PKCE (Proof Key for Code Exchange) authentication configuration.

    This flow is designed for public clients that cannot securely store a client secret,
    such as single-page applications and mobile apps.

    ``AuthManager.get_pkce_authorization_url`` builds the user-facing URL from
    these settings, ``AuthManager.complete_pkce_flow`` exchanges the returned
    code for tokens, and ``AuthManager.refresh_pkce_token`` renews them.
    """

    type: str = "oauth2_pkce"
    client_id: str
    redirect_uri: str  # sent in both the authorization URL and the token request
    authorization_url: str  # endpoint the user visits; query parameters are appended
    token_url: str  # endpoint used for the code exchange and for refreshes
    scope: Optional[str] = None  # added to the authorization URL only when set
    code_verifier: Optional[str] = None  # Random secret from which the challenge is derived; sent when exchanging the code
    code_challenge: Optional[str] = None  # Base64url-encoded (unpadded) SHA-256 digest of code_verifier, sent with method S256
    authorization_code: Optional[str] = None  # Code received from the redirect, recorded after the exchange
    state: Optional[str] = None  # Security state to prevent CSRF attacks. NOTE(review): no visible method checks the state returned in the redirect; confirm the caller does
    token: Optional[str] = None  # current access token
    refresh_token: Optional[str] = None  # used by refresh_pkce_token
    expires_at: Optional[int] = None  # absolute expiry as a Unix timestamp (seconds)
    storage_key: Optional[str] = None  # Key under which tokens are stored in secure storage
class OAuth2DeviceFlow(AuthConfig):
    """OAuth2 Device Flow authentication configuration.

    This flow is designed for devices with limited input capabilities, such as
    TVs, IoT devices, and CLI applications.

    ``AuthManager.start_device_flow`` fills in the device/user codes and
    verification URIs; ``AuthManager.poll_device_flow`` then polls ``token_url``
    until the user has approved the request and fills in the token fields.
    """

    type: str = "oauth2_device_flow"
    client_id: str
    client_secret: Optional[str] = None  # Some providers require this; sent only when set
    device_authorization_url: str  # endpoint that starts the flow
    token_url: str  # endpoint polled for the access token
    scope: Optional[str] = None  # sent when starting the flow, only when set
    device_code: Optional[str] = None  # Code returned by the authorization request; required for polling
    user_code: Optional[str] = None  # Code shown to the user
    verification_uri: Optional[str] = None  # URL where the user enters the code
    verification_uri_complete: Optional[str] = None  # Direct URL with the code embedded
    token: Optional[str] = None  # current access token
    refresh_token: Optional[str] = None
    expires_at: Optional[int] = None  # absolute expiry as a Unix timestamp (seconds)
    interval: Optional[int] = None  # Polling interval in seconds; grown when the server answers slow_down
    storage_key: Optional[str] = None  # Key under which tokens are stored in secure storage
class AuthManager:
    """
    Manager for handling different types of authentication.

    This class creates and manages authentication configurations for different APIs.
    It supports:

    - API Key
    - Bearer Token
    - Basic Auth
    - OAuth2 Client Credentials
    - OAuth2 PKCE (Proof Key for Code Exchange)
    - OAuth2 Device Flow

    It also integrates with the secure credential storage system to safely store
    sensitive authentication information.
    """

    # Token lifetime (seconds) assumed when a token response omits ``expires_in``.
    _DEFAULT_EXPIRES_IN = 3600
    # RFC 8628 section 3.5: poll every 5 seconds unless the server says otherwise,
    # and add 5 seconds to the interval whenever the server answers ``slow_down``.
    _DEFAULT_POLL_INTERVAL = 5
    _SLOW_DOWN_INCREMENT = 5

    def __init__(self, secure_storage=None) -> None:
        """
        Create a manager.

        Args:
            secure_storage: Optional credential store. When set, tokens obtained by
                the OAuth2 methods are persisted through its
                ``store_credential(key, data)`` method under the config's
                ``storage_key``.
        """
        logger.debug("Initialized AuthManager")
        self.secure_storage = secure_storage

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _generate_random_string(self, length: int = 64) -> str:
        """
        Generate a random string for a PKCE code verifier or ``state`` value.

        Characters come from the RFC 7636 "unreserved" alphabet. ``secrets`` is
        used instead of ``random`` because verifiers and state values must be
        unpredictable, and ``random`` is not a cryptographically secure generator.

        Args:
            length: Number of characters to produce. RFC 7636 requires 43-128
                characters for a code verifier.

        Returns:
            The generated string.
        """
        alphabet = string.ascii_letters + string.digits + "-._~"
        return "".join(secrets.choice(alphabet) for _ in range(length))

    def _create_code_challenge(self, code_verifier: str) -> str:
        """
        Create an S256 code challenge from a code verifier.

        Returns the base64url encoding of the verifier's SHA-256 digest with the
        trailing ``=`` padding removed.
        """
        digest = hashlib.sha256(code_verifier.encode()).digest()
        return base64.urlsafe_b64encode(digest).decode().rstrip("=")

    def _resolve_env_vars(self, value: Any) -> Any:
        """
        Resolve environment variables in a string value.

        Only a value that is entirely a reference such as ``${VAR_NAME}`` is
        substituted. An unset variable is logged and the reference is returned
        unchanged. Non-string values are returned as-is.

        Args:
            value: Value that may contain environment variable references like ${VAR_NAME}

        Returns:
            Value with environment variables resolved
        """
        if not isinstance(value, str):
            return value
        if value.startswith("${") and value.endswith("}"):
            env_var = value[2:-1]
            env_value = os.environ.get(env_var)
            if env_value is None:
                logger.warning(f"Environment variable {env_var} not found")
                return value
            return env_value
        return value

    def _resolve_env_vars_in_value(self, value: Any) -> Any:
        """Resolve ``${VAR}`` references in a value, recursing into dicts and lists."""
        if isinstance(value, dict):
            return self._resolve_env_vars_in_dict(value)
        if isinstance(value, list):
            return [self._resolve_env_vars_in_value(item) for item in value]
        return self._resolve_env_vars(value)

    def _resolve_env_vars_in_dict(self, config: Dict[str, Any]) -> Dict[str, Any]:
        """
        Recursively resolve environment variables in a dictionary.

        Nested dictionaries and lists (e.g. a list of scopes) are walked as well.
        A new dictionary is returned and the input is not modified.

        Args:
            config: Dictionary that may contain environment variable references

        Returns:
            Dictionary with environment variables resolved
        """
        return {
            key: self._resolve_env_vars_in_value(value)
            for key, value in config.items()
        }

    def _post_form(self, url: str, data: Dict[str, Any]) -> httpx.Response:
        """POST ``data`` form-encoded to an OAuth endpoint and return the raw response."""
        return httpx.post(
            url,
            data=data,
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
                # Ask for JSON explicitly: some providers (e.g. GitHub) otherwise
                # reply form-encoded, which ``response.json()`` cannot parse.
                "Accept": "application/json",
            },
        )

    def _request_json(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``data`` to ``url``, raise on an HTTP error status, and return the JSON body."""
        response = self._post_form(url, data)
        response.raise_for_status()
        return response.json()

    @staticmethod
    def _oauth_error(response: httpx.Response) -> Optional[str]:
        """Return the OAuth ``error`` code from a JSON body, or None if there is none."""
        try:
            body = response.json()
        except ValueError:  # JSONDecodeError / UnicodeDecodeError: not a JSON body
            return None
        if not isinstance(body, dict):
            return None
        error = body.get("error")
        return error if isinstance(error, str) else None

    def _compute_expires_at(self, token_data: Dict[str, Any]) -> int:
        """
        Turn a token response's ``expires_in`` into an absolute Unix timestamp.

        A missing or null ``expires_in`` falls back to ``_DEFAULT_EXPIRES_IN``.
        Numeric strings are accepted because some providers send ``"3599"``.
        """
        expires_in = token_data.get("expires_in")
        if expires_in is None:
            expires_in = self._DEFAULT_EXPIRES_IN
        return int(time.time() + float(expires_in))

    def _store_tokens(self, storage_key: Optional[str], record: Dict[str, Any]) -> None:
        """Persist ``record`` plus an ``updated_at`` timestamp if storage and a key are set."""
        if self.secure_storage and storage_key:
            self.secure_storage.store_credential(
                storage_key, {**record, "updated_at": int(time.time())}
            )

    @staticmethod
    def _rebuild(model_cls, source, **changes):
        """
        Build a new ``model_cls`` from every field of ``source`` with ``changes`` applied.

        ``model_dump()`` includes extra fields (the models use ``extra="allow"``),
        so nothing the caller attached to the config is silently dropped.
        """
        return model_cls(**{**source.model_dump(), **changes})

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def refresh_oauth2_token(
        self, auth_config: OAuth2ClientCredentials
    ) -> OAuth2ClientCredentials:
        """
        Refresh an OAuth2 token using client credentials flow.

        Args:
            auth_config: OAuth2 client credentials configuration. If it has a
                ``storage_key`` and secure storage is configured, the new token is
                persisted there.

        Returns:
            A new OAuth2ClientCredentials with updated ``token`` and
            ``expires_at``. All other settings, including extra fields such as
            ``storage_key``, are carried over.

        Raises:
            httpx.HTTPStatusError: If the token endpoint returns an error status.
            Exception: Any other failure (network, JSON decoding, storage) is
                logged and re-raised.
        """
        try:
            data = {
                "grant_type": "client_credentials",
                "client_id": auth_config.client_id,
                "client_secret": auth_config.client_secret,
            }
            if auth_config.scope:
                data["scope"] = auth_config.scope
            token_data = self._request_json(auth_config.token_url, data)
            token = token_data.get("access_token")
            expires_at = self._compute_expires_at(token_data)
            logger.debug("OAuth2 token refreshed")
            # storage_key may be a declared field or an extra (extra="allow").
            self._store_tokens(
                getattr(auth_config, "storage_key", None),
                {"token": token, "expires_at": expires_at},
            )
            # Rebuilding from the full dump keeps storage_key. Without it, later
            # refreshes could no longer update secure storage.
            return self._rebuild(
                OAuth2ClientCredentials,
                auth_config,
                token=token,
                expires_at=expires_at,
            )
        except Exception as e:
            logger.error(f"Error refreshing OAuth2 token: {str(e)}")
            raise

    def get_pkce_authorization_url(self, auth_config: OAuth2PKCE) -> str:
        """
        Get the authorization URL for OAuth2 PKCE flow.

        Parameters that are unset are omitted rather than sent as the literal
        string ``"None"``. If ``code_challenge`` is unset but ``code_verifier`` is
        present, the S256 challenge is derived from the verifier. A query string
        already present in ``authorization_url`` is preserved.

        Args:
            auth_config: OAuth2 PKCE configuration

        Returns:
            Authorization URL for the user to visit
        """
        code_challenge = auth_config.code_challenge
        if not code_challenge and auth_config.code_verifier:
            code_challenge = self._create_code_challenge(auth_config.code_verifier)

        params: Dict[str, Any] = {
            "client_id": auth_config.client_id,
            "redirect_uri": auth_config.redirect_uri,
            "response_type": "code",
        }
        if auth_config.state is not None:
            params["state"] = auth_config.state
        if code_challenge:
            params["code_challenge"] = code_challenge
            params["code_challenge_method"] = "S256"
        if auth_config.scope:
            params["scope"] = auth_config.scope

        base_url = auth_config.authorization_url
        if "?" not in base_url:
            separator = "?"
        elif base_url.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"
        return f"{base_url}{separator}{urlencode(params)}"

    def complete_pkce_flow(
        self, auth_config: OAuth2PKCE, authorization_code: str
    ) -> OAuth2PKCE:
        """
        Complete the OAuth2 PKCE flow by exchanging the authorization code for tokens.

        Args:
            auth_config: OAuth2 PKCE configuration
            authorization_code: Authorization code received from redirect

        Returns:
            Updated OAuth2PKCE with token information and the authorization code

        Raises:
            httpx.HTTPStatusError: If the token endpoint returns an error status.
            Exception: Any other failure is logged and re-raised.
        """
        try:
            data = {
                "grant_type": "authorization_code",
                "code": authorization_code,
                "redirect_uri": auth_config.redirect_uri,
                "client_id": auth_config.client_id,
                "code_verifier": auth_config.code_verifier,
            }
            token_data = self._request_json(auth_config.token_url, data)
            token = token_data.get("access_token")
            refresh_token = token_data.get("refresh_token")
            expires_at = self._compute_expires_at(token_data)
            logger.debug("OAuth2 PKCE flow completed successfully")
            self._store_tokens(
                auth_config.storage_key,
                {
                    "token": token,
                    "refresh_token": refresh_token,
                    "expires_at": expires_at,
                },
            )
            return self._rebuild(
                OAuth2PKCE,
                auth_config,
                authorization_code=authorization_code,
                token=token,
                refresh_token=refresh_token,
                expires_at=expires_at,
            )
        except Exception as e:
            logger.error(f"Error completing OAuth2 PKCE flow: {str(e)}")
            raise

    def refresh_pkce_token(self, auth_config: OAuth2PKCE) -> OAuth2PKCE:
        """
        Refresh an OAuth2 token obtained via PKCE flow.

        If the server does not return a new refresh token (key absent or null),
        the existing one is kept.

        Args:
            auth_config: OAuth2 PKCE configuration with refresh token

        Returns:
            Updated OAuth2PKCE with new token information

        Raises:
            ValueError: If ``auth_config`` has no refresh token.
            httpx.HTTPStatusError: If the token endpoint returns an error status.
        """
        if not auth_config.refresh_token:
            raise ValueError("Refresh token is required to refresh PKCE token")
        try:
            data = {
                "grant_type": "refresh_token",
                "refresh_token": auth_config.refresh_token,
                "client_id": auth_config.client_id,
            }
            token_data = self._request_json(auth_config.token_url, data)
            token = token_data.get("access_token")
            # A null refresh_token must not erase the one we already hold.
            refresh_token = token_data.get("refresh_token") or auth_config.refresh_token
            expires_at = self._compute_expires_at(token_data)
            logger.debug("OAuth2 PKCE token refreshed")
            self._store_tokens(
                auth_config.storage_key,
                {
                    "token": token,
                    "refresh_token": refresh_token,
                    "expires_at": expires_at,
                },
            )
            return self._rebuild(
                OAuth2PKCE,
                auth_config,
                token=token,
                refresh_token=refresh_token,
                expires_at=expires_at,
            )
        except Exception as e:
            logger.error(f"Error refreshing OAuth2 PKCE token: {str(e)}")
            raise

    def start_device_flow(self, auth_config: OAuth2DeviceFlow) -> OAuth2DeviceFlow:
        """
        Start the OAuth2 device flow authentication process.

        Args:
            auth_config: OAuth2 Device Flow configuration

        Returns:
            Updated OAuth2DeviceFlow with device code information. Any previous
            token, refresh token, and expiry are cleared because a new flow is
            starting.

        Raises:
            httpx.HTTPStatusError: If the device authorization endpoint returns an
                error status.
        """
        try:
            data = {
                "client_id": auth_config.client_id,
            }
            if auth_config.client_secret:
                data["client_secret"] = auth_config.client_secret
            if auth_config.scope:
                data["scope"] = auth_config.scope
            device_data = self._request_json(auth_config.device_authorization_url, data)
            # Some providers use the older "verification_url" spelling.
            verification_uri = device_data.get("verification_uri") or device_data.get(
                "verification_url"
            )
            # expires_in not used in client; polling interval governs timing
            interval = device_data.get("interval") or self._DEFAULT_POLL_INTERVAL
            logger.debug("OAuth2 device flow started successfully")
            return self._rebuild(
                OAuth2DeviceFlow,
                auth_config,
                device_code=device_data.get("device_code"),
                user_code=device_data.get("user_code"),
                verification_uri=verification_uri,
                verification_uri_complete=device_data.get("verification_uri_complete"),
                interval=interval,
                token=None,
                refresh_token=None,
                expires_at=None,
            )
        except Exception as e:
            logger.error(f"Error starting OAuth2 device flow: {str(e)}")
            raise

    def poll_device_flow(
        self, auth_config: OAuth2DeviceFlow
    ) -> Tuple[bool, Optional[OAuth2DeviceFlow]]:
        """
        Poll for device flow completion and token retrieval.

        Args:
            auth_config: OAuth2 Device Flow configuration with device code

        Returns:
            Tuple of (completed, updated_config)
            - completed: True if the flow is complete, False if still pending
            - updated_config: Updated OAuth2DeviceFlow with token information if
              completed; ``auth_config`` itself (with an increased ``interval``)
              after ``slow_down``; None while authorization is pending.

        Raises:
            ValueError: If ``auth_config`` has no device code.
            httpx.HTTPStatusError: For terminal errors (e.g. ``access_denied``,
                ``expired_token``), whether they arrive with an error status or
                embedded in a 2xx response.
        """
        if not auth_config.device_code:
            raise ValueError(
                "Device code is required to poll for device flow completion"
            )
        try:
            data = {
                "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
                "device_code": auth_config.device_code,
                "client_id": auth_config.client_id,
            }
            if auth_config.client_secret:
                data["client_secret"] = auth_config.client_secret
            response = self._post_form(auth_config.token_url, data)
            # RFC 8628 sends pending/slow_down with HTTP 400, but some providers
            # (GitHub, for example) use HTTP 200. Inspect the body regardless of status.
            error = self._oauth_error(response)
            if error == "authorization_pending":
                logger.debug("Device flow authorization pending")
                return False, None
            if error == "slow_down":
                logger.debug("Device flow polling too fast, slowing down")
                current = auth_config.interval or self._DEFAULT_POLL_INTERVAL
                auth_config.interval = current + self._SLOW_DOWN_INCREMENT
                return False, auth_config
            # Handle other errors
            response.raise_for_status()
            if error:
                # Terminal error delivered with a 2xx status: surface it the same
                # way as the error-status case so callers handle one exception type.
                raise httpx.HTTPStatusError(
                    f"Device flow authorization failed: {error}",
                    request=response.request,
                    response=response,
                )
            # Authorization complete, extract tokens
            token_data = response.json()
            token = token_data.get("access_token")
            refresh_token = token_data.get("refresh_token")
            expires_at = self._compute_expires_at(token_data)
            logger.debug("OAuth2 device flow completed successfully")
            self._store_tokens(
                auth_config.storage_key,
                {
                    "token": token,
                    "refresh_token": refresh_token,
                    "expires_at": expires_at,
                },
            )
            updated_config = self._rebuild(
                OAuth2DeviceFlow,
                auth_config,
                token=token,
                refresh_token=refresh_token,
                expires_at=expires_at,
            )
            return True, updated_config
        except httpx.HTTPStatusError as e:
            # Handle other HTTP errors
            logger.error(f"HTTP error during device flow polling: {str(e)}")
            raise
        except Exception as e:
            logger.error(f"Error polling OAuth2 device flow: {str(e)}")
            raise