# Mirror of https://github.com/x1xhlol/system-prompts-and-models-of-ai-tools.git
# (synced 2026-06-18 15:29:36 +00:00) — 82 lines, 2.9 KiB, Python.
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from typing import Optional
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from uuid import UUID
|
|
from app.database import IS_SQLITE, get_db
|
|
from app.utils.security import decode_token
|
|
from app.models.user import User
|
|
from app.models.tenant import Tenant
|
|
|
|
# Bearer scheme for endpoints that require authentication; auto_error=True is
# HTTPBearer's default and makes FastAPI reject requests lacking credentials.
security = HTTPBearer(auto_error=True)

# Bearer scheme for endpoints that also serve anonymous callers: with
# auto_error disabled the dependency yields None when no credentials are sent.
optional_security = HTTPBearer(auto_error=False)
|
|
|
|
|
|
def _user_id_clause(user_id: str):
    """Build a SQLAlchemy clause matching ``User.id`` against *user_id*.

    SQLite stores UUID primary keys as 36-character strings, whereas Postgres
    uses a native UUID column, so the value is coerced to whichever form the
    active backend compares against.

    Raises:
        ValueError: on non-SQLite backends, if ``str(user_id)`` is not a
            well-formed UUID.
    """
    normalized = str(user_id)
    if IS_SQLITE:
        return User.id == normalized
    return User.id == UUID(normalized)
|
|
|
|
|
|
def _tenant_id_clause(tenant_id):
    """Build a SQLAlchemy clause matching ``Tenant.id`` against *tenant_id*.

    Mirrors the user-id helper: SQLite compares against the string form of the
    id, other backends against a ``uuid.UUID`` parsed from that string.

    Raises:
        ValueError: on non-SQLite backends, if ``str(tenant_id)`` is not a
            well-formed UUID (note ``str(None)`` is ``"None"``).
    """
    key = str(tenant_id)
    if not IS_SQLITE:
        key = UUID(key)
    return Tenant.id == key
|
|
|
|
|
|
async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(optional_security),
    db: AsyncSession = Depends(get_db),
) -> Optional[User]:
    """Resolve the caller to an active ``User`` when a usable access token is sent.

    Returns ``None`` instead of raising an HTTP error in each of these cases:
    no credentials, a token that fails to decode, a token whose ``type`` is not
    ``"access"``, a missing or non-UUID ``sub`` claim, or no active user with
    that id. Endpoints can then treat the caller as anonymous.
    """
    if credentials is None:
        return None

    payload = decode_token(credentials.credentials)
    # A falsy payload means the token could not be decoded; only tokens
    # explicitly typed "access" may authenticate a request.
    if not payload or payload.get("type") != "access":
        return None

    user_id = payload.get("sub")
    if not user_id:
        return None

    try:
        clause = _user_id_clause(user_id)
    except ValueError:
        # On non-SQLite backends the subject is parsed with UUID(); a malformed
        # value cannot match any user and must not escape as a 500.
        return None

    result = await db.execute(select(User).where(clause, User.is_active == True))  # noqa: E712
    return result.scalar_one_or_none()
|
|
|
|
|
|
def _credentials_error(detail: str) -> HTTPException:
    """Build a 401 error carrying the Bearer authentication challenge.

    A 401 response must include a ``WWW-Authenticate`` header (RFC 7235);
    ``Bearer`` tells the client which scheme to retry with (RFC 6750).
    """
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Resolve the request's bearer access token to an active ``User``.

    Raises:
        HTTPException: 401 (with ``WWW-Authenticate: Bearer``) if the token
            fails to decode or is not an access token, if its ``sub`` claim is
            missing or not a valid id, or if no active user matches it.
    """
    payload = decode_token(credentials.credentials)
    # A falsy payload means the token could not be decoded; only tokens
    # explicitly typed "access" may authenticate a request.
    if not payload or payload.get("type") != "access":
        raise _credentials_error("Invalid or expired token")

    user_id = payload.get("sub")
    if not user_id:
        raise _credentials_error("Invalid token payload")

    try:
        clause = _user_id_clause(user_id)
    except ValueError:
        # On non-SQLite backends the subject is parsed with UUID(); a malformed
        # value is a bad credential (401), not a server error (500).
        raise _credentials_error("Invalid token payload") from None

    result = await db.execute(select(User).where(clause, User.is_active == True))  # noqa: E712
    user = result.scalar_one_or_none()
    if not user:
        raise _credentials_error("User not found or inactive")
    return user
|
|
|
|
|
|
async def get_current_tenant(
    current_user: User = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Tenant:
    """Return the active ``Tenant`` the authenticated user belongs to.

    Raises:
        HTTPException: 403 if the user has no (or a malformed) tenant id, or if
            the tenant does not exist or is inactive.
    """
    forbidden_detail = "Tenant not found or inactive"

    tenant_id = current_user.tenant_id
    # NOTE(review): nullability of User.tenant_id is not visible here. Guard
    # anyway: str(None) == "None" makes UUID() raise on non-SQLite backends,
    # which would surface as a 500 instead of the 403 SQLite already returns.
    if tenant_id is None:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=forbidden_detail)

    try:
        clause = _tenant_id_clause(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail=forbidden_detail
        ) from None

    result = await db.execute(select(Tenant).where(clause, Tenant.is_active == True))  # noqa: E712
    tenant = result.scalar_one_or_none()
    if not tenant:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=forbidden_detail)
    return tenant
|
|
|
|
|
|
def require_role(*roles: str):
    """Create a dependency that admits only users whose ``role`` is in *roles*.

    The returned coroutine depends on ``get_current_user``. It hands back the
    authenticated user when ``current_user.role`` is one of *roles*, and
    raises a 403 ``HTTPException`` otherwise. Called with no roles, it rejects
    every user.
    """
    allowed = roles

    async def role_checker(current_user: User = Depends(get_current_user)):
        if current_user.role in allowed:
            return current_user
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Insufficient permissions")

    return role_checker
|