137 lines
4.9 KiB
Python
137 lines
4.9 KiB
Python
from datetime import datetime, timezone
|
|
from typing import Optional, Tuple
|
|
|
|
import httpx
|
|
from sqlmodel import Session, select
|
|
|
|
from app.core.config import settings
|
|
from app.core.security import (
|
|
hash_password,
|
|
verify_password,
|
|
create_access_token,
|
|
create_refresh_token,
|
|
)
|
|
from app.models.models import User, RefreshToken
|
|
from app.services.user_service import get_user_by_email, create_user
|
|
|
|
|
|
def register(session: Session, email: str, password: str) -> Tuple[User, str, str]:
    """Create a password-based account and issue its first token pair.

    Args:
        session: Database session used for the lookup and for persisting
            the refresh-token row.
        email: Address the new account is registered under.
        password: Plain-text password; it is passed through
            ``hash_password`` before being handed to ``create_user``.

    Returns:
        A ``(user, access_token, refresh_token)`` tuple.

    Raises:
        ValueError: If ``get_user_by_email`` already finds a user for
            ``email``.
    """
    # Guard clause: refuse to register the same address twice.
    if get_user_by_email(session, email):
        raise ValueError("User already exists")

    password_hash = hash_password(password)
    new_user = create_user(session, email, password_hash)

    access_jwt, _ = create_access_token(str(new_user.id))
    refresh_jwt, refresh_expiry = create_refresh_token(str(new_user.id))

    # Persist the refresh token so it can later be looked up by value.
    session.add(
        RefreshToken(
            user_id=new_user.id,
            token=refresh_jwt,
            created_at=datetime.now(timezone.utc),
            expires_at=refresh_expiry,
        )
    )
    session.commit()

    return new_user, access_jwt, refresh_jwt
|
|
|
|
|
|
def login(session: Session, email: str, password: str) -> Tuple[User, str, str]:
    """Authenticate with email and password and issue a fresh token pair.

    Args:
        session: Database session used for the lookup and for persisting
            the refresh-token row.
        email: Address identifying the account.
        password: Plain-text password supplied by the caller.

    Returns:
        A ``(user, access_token, refresh_token)`` tuple.

    Raises:
        ValueError: With the same ``"Invalid credentials"`` message when
            the user is missing, has no stored password hash, or
            ``verify_password`` returns a falsy result.
    """
    account = get_user_by_email(session, email)

    # Short-circuits left to right: a missing user or an account without a
    # stored hash never reaches verify_password.
    # NOTE(review): arguments are passed as (stored hash, candidate password);
    # confirm this matches verify_password's parameter order.
    authenticated = (
        account
        and account.hashed_password
        and verify_password(account.hashed_password, password)
    )
    if not authenticated:
        raise ValueError("Invalid credentials")

    access_jwt, _ = create_access_token(str(account.id))
    refresh_jwt, refresh_expiry = create_refresh_token(str(account.id))

    # Persist the refresh token so it can later be looked up by value.
    session.add(
        RefreshToken(
            user_id=account.id,
            token=refresh_jwt,
            created_at=datetime.now(timezone.utc),
            expires_at=refresh_expiry,
        )
    )
    session.commit()

    return account, access_jwt, refresh_jwt
|
|
|
|
|
|
def refresh_token(session: Session, token: str) -> str:
    """Exchange a stored refresh token for a new access token.

    Args:
        session: Database session used to look up the refresh-token row.
        token: Refresh-token string presented by the client.

    Returns:
        A newly created access token for the user that owns ``token``.

    Raises:
        ValueError: If no ``RefreshToken`` row matches ``token``, or if
            the matching row's ``expires_at`` is not in the future.
    """
    statement = select(RefreshToken).where(RefreshToken.token == token)
    rt = session.exec(statement).first()
    if not rt:
        raise ValueError("Invalid refresh token")

    # A token that merely exists in the table is not enough: honour the
    # expiry that was stored alongside it when it was issued.
    expires_at = rt.expires_at
    if expires_at is not None:
        if expires_at.tzinfo is None:
            # Some database backends return datetimes without tzinfo.
            # Assumed to be UTC, like the created_at values written by this
            # module - TODO confirm against create_refresh_token.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        if expires_at <= datetime.now(timezone.utc):
            raise ValueError("Refresh token expired")

    access, _ = create_access_token(str(rt.user_id))
    return access
|
|
|
|
|
|
def _create_or_get_user_from_oauth(session: Session, email: str) -> User:
    """Return the user registered under ``email``, creating one if needed.

    Args:
        session: Database session passed through to the user service.
        email: Address reported by the OAuth provider.

    Returns:
        The user found by ``get_user_by_email``, or a newly created one.
    """
    # A user created here is OAuth-only, so no password hash is stored (None).
    return get_user_by_email(session, email) or create_user(session, email, None)
|
|
|
|
|
|
def handle_oauth_callback(session: Session, provider: str, code: str) -> Tuple[User, str, str]:
    """Complete an OAuth login and issue a local token pair.

    The authorization ``code`` is exchanged for a provider access token,
    the account's email address is fetched from the provider, and the
    matching local user is looked up (or created). A new access/refresh
    token pair is then issued and the refresh token is persisted.

    Args:
        session: Database session used for the user lookup and for
            persisting the refresh-token row.
        provider: ``"github"`` or ``"google"``.
        code: Authorization code received on the OAuth redirect.

    Returns:
        A ``(user, access_token, refresh_token)`` tuple.

    Raises:
        ValueError: If the provider is unsupported, the token exchange
            yields no access token, or the provider reports no usable
            (verified) email address.
        httpx.HTTPStatusError: Propagated from ``raise_for_status()`` when
            a provider endpoint answers with an error status.
    """
    if provider == "github":
        token_resp = _github_exchange_code(code)
        access_token = token_resp.get("access_token")
        if not access_token:
            raise ValueError("Failed to obtain access token from GitHub")
        email = _fetch_github_verified_email(access_token)
    elif provider == "google":
        token_resp = _google_exchange_code(code)
        access_token = token_resp.get("access_token")
        if not access_token:
            raise ValueError("Failed to obtain access token from Google")
        email = _fetch_google_email(access_token)
    else:
        raise ValueError("Unsupported provider")

    user = _create_or_get_user_from_oauth(session, email)

    access, _ = create_access_token(str(user.id))
    refresh, exp = create_refresh_token(str(user.id))
    # Timezone-aware timestamp, consistent with register() and login().
    rt = RefreshToken(
        user_id=user.id,
        token=refresh,
        created_at=datetime.now(timezone.utc),
        expires_at=exp,
    )
    session.add(rt)
    session.commit()
    return user, access, refresh


def _fetch_github_verified_email(access_token: str) -> str:
    """Return a verified email address of the GitHub account, preferring the primary one."""
    headers = {"Authorization": f"token {access_token}", "Accept": "application/vnd.github+json"}
    resp = httpx.get("https://api.github.com/user/emails", headers=headers, timeout=10.0)
    resp.raise_for_status()

    # Only verified addresses are trusted: the address is matched against
    # existing local accounts, so accepting an unverified one would let
    # anyone claim another person's account.
    verified = [e for e in resp.json() if e.get("verified") and e.get("email")]
    chosen = next((e for e in verified if e.get("primary")), None)
    if chosen is None and verified:
        chosen = verified[0]
    if chosen is None:
        raise ValueError("No verified email found from GitHub")
    return chosen["email"]


def _fetch_google_email(access_token: str) -> str:
    """Return the email address from the Google userinfo endpoint."""
    resp = httpx.get(
        "https://openidconnect.googleapis.com/v1/userinfo",
        headers={"Authorization": f"Bearer {access_token}"},
        timeout=10.0,
    )
    resp.raise_for_status()
    profile = resp.json()

    email = profile.get("email")
    if not email:
        raise ValueError("No email in Google profile")
    # Reject only an explicit "not verified"; a profile without the claim is
    # accepted as before.
    if profile.get("email_verified") is False:
        raise ValueError("Google email is not verified")
    return email
|
|
|
|
|
|
def _github_exchange_code(code: str) -> dict:
    """Exchange a GitHub authorization code at GitHub's token endpoint.

    Args:
        code: Authorization code received on the OAuth redirect.

    Returns:
        The decoded JSON body of the token response.

    Raises:
        httpx.HTTPStatusError: Propagated from ``raise_for_status()`` when
            the endpoint answers with an error status.
    """
    # The Accept header asks for a JSON response so the body can be decoded
    # with .json().
    response = httpx.post(
        "https://github.com/login/oauth/access_token",
        data={
            "client_id": settings.GITHUB_CLIENT_ID,
            "client_secret": settings.GITHUB_CLIENT_SECRET,
            "code": code,
        },
        headers={"Accept": "application/json"},
        timeout=10.0,
    )
    response.raise_for_status()
    return response.json()
|
|
|
|
|
|
def _google_exchange_code(code: str) -> dict:
    """Exchange a Google authorization code at Google's token endpoint.

    Args:
        code: Authorization code received on the OAuth redirect.

    Returns:
        The decoded JSON body of the token response.

    Raises:
        httpx.HTTPStatusError: Propagated from ``raise_for_status()`` when
            the endpoint answers with an error status.
    """
    response = httpx.post(
        "https://oauth2.googleapis.com/token",
        data={
            "client_id": settings.GOOGLE_CLIENT_ID,
            "client_secret": settings.GOOGLE_CLIENT_SECRET,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": settings.GOOGLE_REDIRECT_URL,
        },
        timeout=10.0,
    )
    response.raise_for_status()
    return response.json()
|
|
|