diff options
Diffstat (limited to 'backend')
| -rw-r--r-- | backend/app/auth/__init__.py | 17 | ||||
| -rw-r--r-- | backend/app/auth/models.py | 41 | ||||
| -rw-r--r-- | backend/app/auth/routes.py | 222 | ||||
| -rw-r--r-- | backend/app/auth/utils.py | 73 | ||||
| -rw-r--r-- | backend/app/main.py | 33 | ||||
| -rw-r--r-- | backend/data/users.db | bin | 0 -> 20480 bytes | |||
| -rw-r--r-- | backend/requirements.txt | 6 |
7 files changed, 390 insertions, 2 deletions
diff --git a/backend/app/auth/__init__.py b/backend/app/auth/__init__.py new file mode 100644 index 0000000..8234b6f --- /dev/null +++ b/backend/app/auth/__init__.py @@ -0,0 +1,17 @@ +from .routes import router as auth_router +from .routes import get_current_user, get_current_user_optional +from .models import User, get_db, init_db +from .utils import Token, UserCreate, UserResponse + +__all__ = [ + 'auth_router', + 'get_current_user', + 'get_current_user_optional', + 'User', + 'get_db', + 'init_db', + 'Token', + 'UserCreate', + 'UserResponse', +] + diff --git a/backend/app/auth/models.py b/backend/app/auth/models.py new file mode 100644 index 0000000..76c33fa --- /dev/null +++ b/backend/app/auth/models.py @@ -0,0 +1,41 @@ +import os +from sqlalchemy import Column, Integer, String, DateTime, create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from datetime import datetime + +# Database configuration +DATA_ROOT = os.path.abspath(os.getenv("DATA_ROOT", os.path.join(os.getcwd(), "data"))) +DATABASE_PATH = os.path.join(DATA_ROOT, "users.db") +DATABASE_URL = f"sqlite:///{DATABASE_PATH}" + +engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False}) +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) +Base = declarative_base() + + +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, index=True) + username = Column(String(50), unique=True, index=True, nullable=False) + email = Column(String(100), unique=True, index=True, nullable=False) + hashed_password = Column(String(255), nullable=False) + created_at = Column(DateTime, default=datetime.utcnow) + is_active = Column(Integer, default=1) + + +def init_db(): + """Initialize database tables""" + os.makedirs(DATA_ROOT, exist_ok=True) + Base.metadata.create_all(bind=engine) + + +def get_db(): + """Dependency to get database session""" + db = SessionLocal() + try: + yield db + finally: + db.close() + diff --git a/backend/app/auth/routes.py b/backend/app/auth/routes.py new file mode 100644 index 0000000..7f07c2a --- /dev/null +++ b/backend/app/auth/routes.py @@ -0,0 +1,222 @@ +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from sqlalchemy.orm import Session +from typing import Optional + +from .models import User, get_db +from .utils import ( + Token, UserCreate, UserLogin, UserResponse, + verify_password, get_password_hash, create_access_token, decode_token +) + +router = APIRouter(prefix="/api/auth", tags=["Authentication"]) + +# OAuth2 scheme for token extraction +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login", auto_error=True) +oauth2_scheme_optional = OAuth2PasswordBearer(tokenUrl="/api/auth/login", auto_error=False) + + +async def get_current_user( + token: str = Depends(oauth2_scheme), + db: Session = Depends(get_db) +) -> User: + """ + Dependency: Validate JWT token and return current user. + Raises 401 if token is invalid or user not found. + """ + username = decode_token(token) + if not username: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or expired token", + headers={"WWW-Authenticate": "Bearer"}, + ) + + user = db.query(User).filter(User.username == username).first() + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="User not found", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User account is disabled" + ) + + return user + + +async def get_current_user_optional( + token: Optional[str] = Depends(oauth2_scheme_optional), + db: Session = Depends(get_db) +) -> Optional[User]: + """ + Dependency: Try to get current user, but don't fail if not authenticated. + Returns None if no valid token. + """ + if not token: + return None + + username = decode_token(token) + if not username: + return None + + user = db.query(User).filter(User.username == username).first() + if not user or not user.is_active: + return None + + return user + + +@router.get("/check-username/{username}") +async def check_username(username: str, db: Session = Depends(get_db)): + """ + Check if a username is available. + """ + existing = db.query(User).filter(User.username == username).first() + return {"available": existing is None} + + +@router.get("/check-email/{email}") +async def check_email(email: str, db: Session = Depends(get_db)): + """ + Check if an email is available. + """ + existing = db.query(User).filter(User.email == email).first() + return {"available": existing is None} + + +@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED) +async def register(user_data: UserCreate, db: Session = Depends(get_db)): + """ + Register a new user account. + """ + # Check if username already exists + existing_user = db.query(User).filter(User.username == user_data.username).first() + if existing_user: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Username already registered" + ) + + # Check if email already exists + existing_email = db.query(User).filter(User.email == user_data.email).first() + if existing_email: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Email already registered" + ) + + # Validate password length + if len(user_data.password) < 6: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Password must be at least 6 characters" + ) + + # Create new user + user = User( + username=user_data.username, + email=user_data.email, + hashed_password=get_password_hash(user_data.password) + ) + db.add(user) + db.commit() + db.refresh(user) + + return user + + +@router.post("/login", response_model=Token) +async def login(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): + """ + Login with username and password, returns JWT token. + """ + # Find user by username + user = db.query(User).filter(User.username == form_data.username).first() + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not verify_password(form_data.password, user.hashed_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User account is disabled" + ) + + # Create access token + access_token = create_access_token(data={"sub": user.username}) + + return { + "access_token": access_token, + "token_type": "bearer", + "username": user.username + } + + +@router.post("/login/json", response_model=Token) +async def login_json(user_data: UserLogin, db: Session = Depends(get_db)): + """ + Login with JSON body (alternative to form-data). + """ + # Find user by username + user = db.query(User).filter(User.username == user_data.username).first() + + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + ) + + if not verify_password(user_data.password, user.hashed_password): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="User account is disabled" + ) + + # Create access token + access_token = create_access_token(data={"sub": user.username}) + + return { + "access_token": access_token, + "token_type": "bearer", + "username": user.username + } + + +@router.get("/me", response_model=UserResponse) +async def get_me(current_user: User = Depends(get_current_user)): + """ + Get current authenticated user's info. + """ + return current_user + + +@router.post("/logout") +async def logout(): + """ + Logout endpoint (client should discard the token). + JWT tokens are stateless, so this is just for API completeness. + """ + return {"message": "Successfully logged out"} + diff --git a/backend/app/auth/utils.py b/backend/app/auth/utils.py new file mode 100644 index 0000000..5889279 --- /dev/null +++ b/backend/app/auth/utils.py @@ -0,0 +1,73 @@ +import os +import bcrypt +from datetime import datetime, timedelta +from typing import Optional +from jose import JWTError, jwt +from pydantic import BaseModel, EmailStr + +# Configuration - use environment variables in production +SECRET_KEY = os.getenv("JWT_SECRET_KEY", "contextflow-secret-key-change-in-production-2024") +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("JWT_EXPIRE_MINUTES", "1440")) # 24 hours default + + +# Pydantic models for request/response +class Token(BaseModel): + access_token: str + token_type: str + username: str + + +class TokenData(BaseModel): + username: Optional[str] = None + + +class UserCreate(BaseModel): + username: str + email: EmailStr + password: str + + +class UserLogin(BaseModel): + username: str + password: str + + +class UserResponse(BaseModel): + id: int + username: str + email: str + created_at: datetime + is_active: int + + class Config: + from_attributes = True + + +def verify_password(plain_password: str, hashed_password: str) -> bool: + """Verify a password against its hash""" + return bcrypt.checkpw(plain_password.encode('utf-8'), hashed_password.encode('utf-8')) + + +def get_password_hash(password: str) -> str: + """Hash a password""" + return bcrypt.hashpw(password.encode('utf-8'), bcrypt.gensalt()).decode('utf-8') + + +def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str: + """Create a JWT access token""" + to_encode = data.copy() + expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)) + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + +def decode_token(token: str) -> Optional[str]: + """Decode a JWT token and return the username""" + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str = payload.get("sub") + return username + except JWTError: + return None + diff --git a/backend/app/main.py b/backend/app/main.py index a5f16af..902d693 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1,13 +1,15 @@ import asyncio import tempfile import time -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse, FileResponse from fastapi import UploadFile, File, Form from pydantic import BaseModel from app.schemas import NodeRunRequest, NodeRunResponse, MergeStrategy, Role, Message, Context, LLMConfig, ModelProvider, ReasoningEffort from app.services.llm import llm_streamer, generate_title, get_openai_client +from app.auth import auth_router, get_current_user, init_db, User, get_db +from app.auth.utils import get_password_hash from dotenv import load_dotenv import os import json @@ -15,11 +17,15 @@ import shutil from typing import List, Literal, Optional from uuid import uuid4 from google import genai +from sqlalchemy.orm import Session load_dotenv() app = FastAPI(title="ContextFlow Backend") +# Include authentication router +app.include_router(auth_router) + app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -28,6 +34,31 @@ app.add_middleware( allow_headers=["*"], ) +# Initialize database on startup +@app.on_event("startup") +async def startup_event(): + """Initialize database and create default test user if not exists""" + init_db() + + # Create test user if not exists + from app.auth.models import SessionLocal + db = SessionLocal() + try: + existing = db.query(User).filter(User.username == "test").first() + if not existing: + test_user = User( + username="test", + email="test@contextflow.local", + hashed_password=get_password_hash("114514") + ) + db.add(test_user) + db.commit() + print("[startup] Created default test user (test/114514)") + else: + print("[startup] Test user already exists") + finally: + db.close() + # --------- Project / Blueprint storage --------- DATA_ROOT = os.path.abspath(os.getenv("DATA_ROOT", os.path.join(os.getcwd(), "data"))) DEFAULT_USER = "test" diff --git a/backend/data/users.db b/backend/data/users.db Binary files differnew file mode 100644 index 0000000..9630889 --- /dev/null +++ b/backend/data/users.db diff --git a/backend/requirements.txt b/backend/requirements.txt index e340864..a9607fd 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -1,9 +1,13 @@ fastapi uvicorn -pydantic +pydantic[email] openai google-generativeai python-dotenv httpx python-multipart +# Authentication +python-jose[cryptography] +passlib[bcrypt] +sqlalchemy |
