80 lines
2.7 KiB
Python
80 lines
2.7 KiB
Python
from datetime import datetime
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from pydantic import EmailStr
|
|
from sqlmodel import Session
|
|
|
|
from app.api.deps import get_db
|
|
from app.schemas.schemas import UserCreate, Token
|
|
from app.services import auth_service
|
|
from app.models.models import RefreshToken
|
|
|
|
# Router that collects the authentication endpoints defined in this module.
router = APIRouter()
|
|
|
|
|
|
@router.post("/register", response_model=Token)
def register(data: UserCreate, db: Session = Depends(get_db)):
    """Create a new account and return a fresh token pair.

    Args:
        data: Registration payload; its ``email`` and ``password`` are passed
            to ``auth_service.register``.
        db: Database session injected by ``get_db``.

    Returns:
        A dict with ``access_token`` and ``refresh_token``, serialized through
        the ``Token`` response model.

    Raises:
        HTTPException: 400 when ``auth_service.register`` raises
            ``ValueError``; the exception message becomes the error detail.
    """
    try:
        # Names chosen so the module-level ``refresh`` endpoint is not shadowed.
        _user, access_token, refresh_token = auth_service.register(
            db, data.email, data.password
        )
    except ValueError as exc:
        # NOTE(review): assumes auth_service's ValueError messages are safe to
        # show to clients — confirm. Chain the cause for debugging (ruff B904).
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
        ) from exc
    return {"access_token": access_token, "refresh_token": refresh_token}
|
|
|
|
|
|
@router.post("/login", response_model=Token)
def login(data: UserCreate, db: Session = Depends(get_db)):
    """Authenticate with email and password and return a fresh token pair.

    Args:
        data: Credentials payload; its ``email`` and ``password`` are passed
            to ``auth_service.login``.
        db: Database session injected by ``get_db``.

    Returns:
        A dict with ``access_token`` and ``refresh_token``, serialized through
        the ``Token`` response model.

    Raises:
        HTTPException: 401 with the generic detail "Invalid credentials"
            whenever ``auth_service.login`` raises ``ValueError``.
    """
    try:
        # Names chosen so the module-level ``refresh`` endpoint is not shadowed.
        _user, access_token, refresh_token = auth_service.login(
            db, data.email, data.password
        )
    except ValueError as exc:
        # The detail is deliberately generic so the response does not reveal
        # which part of the credentials was wrong. RFC 9110 requires a 401 to
        # carry a WWW-Authenticate challenge; "Bearer" assumes the issued
        # access token is used as a bearer token — confirm.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid credentials",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    return {"access_token": access_token, "refresh_token": refresh_token}
|
|
|
|
|
|
@router.post("/refresh")
def refresh(body: dict, db: Session = Depends(get_db)):
    """Exchange a refresh token for a new access token.

    Args:
        body: JSON object expected to contain a ``refresh_token`` string.
        db: Database session injected by ``get_db``.

    Returns:
        ``{"access_token": <token returned by auth_service.refresh_token>}``.

    Raises:
        HTTPException: 400 when ``refresh_token`` is missing, empty, or not a
            string; 401 when ``auth_service.refresh_token`` raises
            ``ValueError``.
    """
    token = body.get("refresh_token")
    # ``body`` is untyped JSON, so the value could be a number, list or object;
    # only a non-empty string is forwarded to the token service.
    if not isinstance(token, str) or not token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="refresh_token required",
        )
    try:
        access_token = auth_service.refresh_token(db, token)
    except ValueError as exc:
        # RFC 9110 requires a WWW-Authenticate challenge on 401 responses;
        # "Bearer" assumes bearer-token usage — confirm.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid refresh token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    return {"access_token": access_token}
|
|
|
|
|
|
@router.get("/oauth/{provider}")
def oauth_start(provider: str):
    """Return the URL that begins the OAuth authorization flow for *provider*.

    Only the exact (case-sensitive) names ``"google"`` and ``"github"`` are
    recognized. The matching URL builder is imported inside the function,
    not at module import time.

    Returns:
        ``{"auth_url": <URL produced by the provider's builder>}``.

    Raises:
        HTTPException: 404 for any other provider name.
    """
    if provider == "google":
        from app.core.oauth import google_authorize_url as build_authorize_url
    elif provider == "github":
        from app.core.oauth import github_authorize_url as build_authorize_url
    else:
        raise HTTPException(status_code=404, detail="Unknown provider")
    return {"auth_url": build_authorize_url()}
|
|
|
|
|
|
@router.get("/oauth/{provider}/callback")
def oauth_callback(
    provider: str, code: str | None = None, db: Session = Depends(get_db)
):
    """Finish an OAuth login by exchanging the provider's ``code`` for tokens.

    Args:
        provider: Provider name from the URL path, passed unchanged to
            ``auth_service.handle_oauth_callback``.
        code: Authorization code from the provider redirect's query string.
        db: Database session injected by ``get_db``.

    Returns:
        A dict with ``access_token`` and ``refresh_token``.

    Raises:
        HTTPException: 400 when ``code`` is missing or empty, or when
            ``auth_service.handle_oauth_callback`` raises ``ValueError`` (its
            message becomes the error detail).
    """
    # NOTE(review): no OAuth ``state`` value is accepted or checked here, and the
    # service is called without one. Unless state is verified elsewhere, this
    # callback has no CSRF protection (RFC 6749 §10.12) — confirm.
    if not code:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Missing code"
        )
    try:
        # Names chosen so the module-level ``refresh`` endpoint is not shadowed.
        _user, access_token, refresh_token = auth_service.handle_oauth_callback(
            db, provider, code
        )
    except ValueError as exc:
        # Chain the original error so server-side tracebacks keep it (ruff B904).
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)
        ) from exc
    return {"access_token": access_token, "refresh_token": refresh_token}
|
|
|
|
|
|
@router.post("/logout")
def logout(body: dict, db: Session = Depends(get_db)):
    """Revoke a refresh token by deleting its stored ``RefreshToken`` row.

    Args:
        body: JSON object expected to contain a ``refresh_token`` string.
        db: Database session injected by ``get_db``.

    Returns:
        ``{"detail": "logged out"}``. This is also returned when no stored row
        matches, so repeating the call or sending an unknown token succeeds.

    Raises:
        HTTPException: 400 when ``refresh_token`` is missing, empty, or not a
            string.
    """
    from sqlmodel import select

    token = body.get("refresh_token")
    # ``body`` is untyped JSON; keep non-string values out of the SQL filter.
    if not isinstance(token, str) or not token:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="refresh_token required",
        )

    # SQLModel deprecates Session.query(); Session.exec(select(...)) is the
    # supported query API.
    stored = db.exec(
        select(RefreshToken).where(RefreshToken.token == token)
    ).first()
    if stored is not None:
        db.delete(stored)
        db.commit()
    return {"detail": "logged out"}
|