first commit
This commit is contained in:
25
app/api/deps.py
Normal file
25
app/api/deps.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.db.session import get_session
|
||||
from app.core.security import decode_token
|
||||
from app.models.models import User
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
||||
|
||||
|
||||
def get_db():
|
||||
yield from get_session()
|
||||
|
||||
|
||||
def get_current_user(token: str = Depends(oauth2_scheme), db: Session = Depends(get_db)) -> User:
|
||||
try:
|
||||
payload = decode_token(token)
|
||||
user_id = int(payload.get("sub"))
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
|
||||
user = db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
|
||||
return user
|
||||
1
app/api/routers/__init__.py
Normal file
1
app/api/routers/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
from . import auth, users
|
||||
79
app/api/routers/auth.py
Normal file
79
app/api/routers/auth.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from pydantic import EmailStr
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.schemas.schemas import UserCreate, Token
|
||||
from app.services import auth_service
|
||||
from app.models.models import RefreshToken
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/register", response_model=Token)
|
||||
def register(data: UserCreate, db: Session = Depends(get_db)):
|
||||
try:
|
||||
user, access, refresh = auth_service.register(db, data.email, data.password)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return {"access_token": access, "refresh_token": refresh}
|
||||
|
||||
|
||||
@router.post("/login", response_model=Token)
|
||||
def login(data: UserCreate, db: Session = Depends(get_db)):
|
||||
try:
|
||||
user, access, refresh = auth_service.login(db, data.email, data.password)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
return {"access_token": access, "refresh_token": refresh}
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
def refresh(body: dict, db: Session = Depends(get_db)):
|
||||
token = body.get("refresh_token")
|
||||
if not token:
|
||||
raise HTTPException(status_code=400, detail="refresh_token required")
|
||||
try:
|
||||
access = auth_service.refresh_token(db, token)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=401, detail="Invalid refresh token")
|
||||
return {"access_token": access}
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}")
|
||||
def oauth_start(provider: str):
|
||||
if provider == "google":
|
||||
from app.core.oauth import google_authorize_url
|
||||
|
||||
return {"auth_url": google_authorize_url()}
|
||||
if provider == "github":
|
||||
from app.core.oauth import github_authorize_url
|
||||
|
||||
return {"auth_url": github_authorize_url()}
|
||||
raise HTTPException(status_code=404, detail="Unknown provider")
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}/callback")
|
||||
def oauth_callback(provider: str, code: str | None = None, db: Session = Depends(get_db)):
|
||||
if not code:
|
||||
raise HTTPException(status_code=400, detail="Missing code")
|
||||
try:
|
||||
user, access, refresh = auth_service.handle_oauth_callback(db, provider, code)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return {"access_token": access, "refresh_token": refresh}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(body: dict, db: Session = Depends(get_db)):
|
||||
token = body.get("refresh_token")
|
||||
if not token:
|
||||
raise HTTPException(status_code=400, detail="refresh_token required")
|
||||
# Invalidate refresh token: simple delete
|
||||
statement = db.query(RefreshToken).filter(RefreshToken.token == token)
|
||||
rt = statement.first()
|
||||
if rt:
|
||||
db.delete(rt)
|
||||
db.commit()
|
||||
return {"detail": "logged out"}
|
||||
12
app/api/routers/users.py
Normal file
12
app/api/routers/users.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlmodel import Session
|
||||
|
||||
from app.api.deps import get_current_user, get_db
|
||||
from app.schemas.schemas import UserRead
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserRead)
|
||||
def read_me(current_user=Depends(get_current_user)):
|
||||
return current_user
|
||||
34
app/core/config.py
Normal file
34
app/core/config.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from __future__ import annotations
|
||||
from pydantic_settings import BaseSettings
|
||||
from pydantic import Field
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = {
|
||||
"env_file": ".env",
|
||||
"extra": "allow",
|
||||
}
|
||||
|
||||
# Read DATABASE_URL from environment via model_config; avoid using Field(..., env=...)
|
||||
DATABASE_URL: str = "sqlite:///./dev.db"
|
||||
JWT_SECRET: str
|
||||
ALGORITHM: str = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 15
|
||||
REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
|
||||
GOOGLE_CLIENT_ID: str | None = None
|
||||
GOOGLE_CLIENT_SECRET: str | None = None
|
||||
GOOGLE_REDIRECT_URL: str | None = None
|
||||
|
||||
GITHUB_CLIENT_ID: str | None = None
|
||||
GITHUB_CLIENT_SECRET: str | None = None
|
||||
GITHUB_REDIRECT_URL: str | None = None
|
||||
|
||||
SERVER_NAME: str = "localhost:8000"
|
||||
DEBUG: bool = True
|
||||
|
||||
|
||||
settings = Settings()
|
||||
27
app/core/oauth.py
Normal file
27
app/core/oauth.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from urllib.parse import urlencode
|
||||
from typing import Dict
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
def google_authorize_url(state: str = "state") -> str:
|
||||
params = {
|
||||
"client_id": settings.GOOGLE_CLIENT_ID,
|
||||
"redirect_uri": settings.GOOGLE_REDIRECT_URL,
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
}
|
||||
return "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params)
|
||||
|
||||
|
||||
def github_authorize_url(state: str = "state") -> str:
|
||||
params = {
|
||||
"client_id": settings.GITHUB_CLIENT_ID,
|
||||
"redirect_uri": settings.GITHUB_REDIRECT_URL,
|
||||
"scope": "user:email",
|
||||
"state": state,
|
||||
}
|
||||
return "https://github.com/login/oauth/authorize?" + urlencode(params)
|
||||
38
app/core/security.py
Normal file
38
app/core/security.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Tuple
|
||||
|
||||
import jwt
|
||||
from argon2 import PasswordHasher
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
ph = PasswordHasher()
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return ph.hash(password)
|
||||
|
||||
|
||||
def verify_password(hash: str, password: str) -> bool:
|
||||
try:
|
||||
return ph.verify(hash, password)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def create_access_token(sub: str) -> Tuple[str, datetime]:
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
payload = {"sub": sub, "exp": expire}
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.ALGORITHM)
|
||||
return token, expire
|
||||
|
||||
|
||||
def create_refresh_token(sub: str) -> Tuple[str, datetime]:
|
||||
expire = datetime.now(timezone.utc) + timedelta(days=settings.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
payload = {"sub": sub, "exp": expire}
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.ALGORITHM)
|
||||
return token, expire
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
|
||||
return jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.ALGORITHM])
|
||||
3
app/db/base.py
Normal file
3
app/db/base.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
__all__ = ["SQLModel"]
|
||||
10
app/db/session.py
Normal file
10
app/db/session.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from sqlmodel import create_engine, Session
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
engine = create_engine(settings.DATABASE_URL, echo=False)
|
||||
|
||||
|
||||
def get_session() -> Session:
|
||||
with Session(engine) as session:
|
||||
yield session
|
||||
26
app/main.py
Normal file
26
app/main.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.session import engine
|
||||
from app.db.base import SQLModel
|
||||
from app.api.routers import auth, users
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(title="FastAPI Account System")
|
||||
app.include_router(auth.router, prefix="/auth", tags=["auth"])
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app) -> AsyncGenerator[None, None]:
|
||||
# For quick local runs create tables automatically. In production use Alembic migrations.
|
||||
SQLModel.metadata.create_all(engine)
|
||||
yield
|
||||
|
||||
app.router.lifespan_context = lifespan
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
20
app/models/models.py
Normal file
20
app/models/models.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlmodel import SQLModel, Field, Column, Integer, String, Boolean, DateTime, ForeignKey
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
email: str = Field(sa_column=Column(String(length=255), unique=True))
|
||||
hashed_password: Optional[str] = Field(default=None)
|
||||
is_active: bool = Field(default=True)
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
|
||||
|
||||
class RefreshToken(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
user_id: int = Field(foreign_key="user.id")
|
||||
token: str
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
expires_at: datetime
|
||||
26
app/schemas/schemas.py
Normal file
26
app/schemas/schemas.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import EmailStr
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
token_type: str = "bearer"
|
||||
|
||||
|
||||
class UserRead(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
email: EmailStr
|
||||
is_active: bool
|
||||
created_at: datetime
|
||||
3
app/services/__init__.py
Normal file
3
app/services/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from . import auth_service, user_service
|
||||
|
||||
__all__ = ["auth_service", "user_service"]
|
||||
136
app/services/auth_service.py
Normal file
136
app/services/auth_service.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import httpx
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.security import (
|
||||
hash_password,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
)
|
||||
from app.models.models import User, RefreshToken
|
||||
from app.services.user_service import get_user_by_email, create_user
|
||||
|
||||
|
||||
def register(session: Session, email: str, password: str) -> Tuple[User, str, str]:
|
||||
existing = get_user_by_email(session, email)
|
||||
if existing:
|
||||
raise ValueError("User already exists")
|
||||
hashed = hash_password(password)
|
||||
user = create_user(session, email, hashed)
|
||||
access, _ = create_access_token(str(user.id))
|
||||
refresh, exp = create_refresh_token(str(user.id))
|
||||
rt = RefreshToken(user_id=user.id, token=refresh, created_at=datetime.now(timezone.utc), expires_at=exp)
|
||||
session.add(rt)
|
||||
session.commit()
|
||||
return user, access, refresh
|
||||
|
||||
|
||||
def login(session: Session, email: str, password: str) -> Tuple[User, str, str]:
|
||||
user = get_user_by_email(session, email)
|
||||
if not user or not user.hashed_password:
|
||||
raise ValueError("Invalid credentials")
|
||||
if not verify_password(user.hashed_password, password):
|
||||
raise ValueError("Invalid credentials")
|
||||
access, _ = create_access_token(str(user.id))
|
||||
refresh, exp = create_refresh_token(str(user.id))
|
||||
rt = RefreshToken(user_id=user.id, token=refresh, created_at=datetime.now(timezone.utc), expires_at=exp)
|
||||
session.add(rt)
|
||||
session.commit()
|
||||
return user, access, refresh
|
||||
|
||||
|
||||
def refresh_token(session: Session, token: str) -> str:
|
||||
statement = select(RefreshToken).where(RefreshToken.token == token)
|
||||
rt = session.exec(statement).first()
|
||||
if not rt:
|
||||
raise ValueError("Invalid refresh token")
|
||||
access, _ = create_access_token(str(rt.user_id))
|
||||
return access
|
||||
|
||||
|
||||
def _create_or_get_user_from_oauth(session: Session, email: str) -> User:
|
||||
user = get_user_by_email(session, email)
|
||||
if user:
|
||||
return user
|
||||
# OAuth-only user: hashed_password is None
|
||||
user = create_user(session, email, None)
|
||||
return user
|
||||
|
||||
|
||||
def handle_oauth_callback(session: Session, provider: str, code: str) -> Tuple[User, str, str]:
|
||||
if provider == "github":
|
||||
token_resp = _github_exchange_code(code)
|
||||
access_token = token_resp.get("access_token")
|
||||
if not access_token:
|
||||
raise ValueError("Failed to obtain access token from GitHub")
|
||||
# fetch emails
|
||||
headers = {"Authorization": f"token {access_token}", "Accept": "application/vnd.github+json"}
|
||||
resp = httpx.get("https://api.github.com/user/emails", headers=headers, timeout=10.0)
|
||||
resp.raise_for_status()
|
||||
emails = resp.json()
|
||||
primary = None
|
||||
for e in emails:
|
||||
if e.get("primary"):
|
||||
primary = e.get("email")
|
||||
break
|
||||
if not primary and emails:
|
||||
primary = emails[0].get("email")
|
||||
if not primary:
|
||||
raise ValueError("No email found from GitHub")
|
||||
user = _create_or_get_user_from_oauth(session, primary)
|
||||
elif provider == "google":
|
||||
token_resp = _google_exchange_code(code)
|
||||
access_token = token_resp.get("access_token")
|
||||
if not access_token:
|
||||
raise ValueError("Failed to obtain access token from Google")
|
||||
resp = httpx.get(
|
||||
"https://openidconnect.googleapis.com/v1/userinfo",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
timeout=10.0,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
profile = resp.json()
|
||||
email = profile.get("email")
|
||||
if not email:
|
||||
raise ValueError("No email in Google profile")
|
||||
user = _create_or_get_user_from_oauth(session, email)
|
||||
else:
|
||||
raise ValueError("Unsupported provider")
|
||||
|
||||
access, _ = create_access_token(str(user.id))
|
||||
refresh, exp = create_refresh_token(str(user.id))
|
||||
rt = RefreshToken(user_id=user.id, token=refresh, created_at=datetime.utcnow(), expires_at=exp)
|
||||
session.add(rt)
|
||||
session.commit()
|
||||
return user, access, refresh
|
||||
|
||||
|
||||
def _github_exchange_code(code: str) -> dict:
|
||||
token_url = "https://github.com/login/oauth/access_token"
|
||||
data = {
|
||||
"client_id": settings.GITHUB_CLIENT_ID,
|
||||
"client_secret": settings.GITHUB_CLIENT_SECRET,
|
||||
"code": code,
|
||||
}
|
||||
resp = httpx.post(token_url, data=data, headers={"Accept": "application/json"}, timeout=10.0)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
|
||||
def _google_exchange_code(code: str) -> dict:
|
||||
token_url = "https://oauth2.googleapis.com/token"
|
||||
data = {
|
||||
"client_id": settings.GOOGLE_CLIENT_ID,
|
||||
"client_secret": settings.GOOGLE_CLIENT_SECRET,
|
||||
"code": code,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": settings.GOOGLE_REDIRECT_URL,
|
||||
}
|
||||
resp = httpx.post(token_url, data=data, timeout=10.0)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
17
app/services/user_service.py
Normal file
17
app/services/user_service.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from typing import Optional
|
||||
from sqlmodel import Session, select
|
||||
|
||||
from app.models.models import User
|
||||
|
||||
|
||||
def get_user_by_email(session: Session, email: str) -> Optional[User]:
|
||||
statement = select(User).where(User.email == email)
|
||||
return session.exec(statement).first()
|
||||
|
||||
|
||||
def create_user(session: Session, email: str, hashed_password: Optional[str]) -> User:
|
||||
user = User(email=email, hashed_password=hashed_password)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
session.refresh(user)
|
||||
return user
|
||||
Reference in New Issue
Block a user