-
Notifications
You must be signed in to change notification settings - Fork 0
feat: Implement Sample Paper Generator with format detection and question generation #8
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
3c34532
ea8adb4
334514d
811d17c
8fd6005
8c0f79b
3dad1ef
dd3fee4
6cf4aeb
271a024
dc4b94a
7ed7b85
fde1188
c1eecaf
81da731
c8139f8
61f54fe
bce477e
cfc15d2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| """add selected_resource_ids to chat_sessions | ||
|
|
||
| Revision ID: 983ad2983d5c | ||
| Revises: d3dcf6b6fe3f | ||
| Create Date: 2026-05-18 20:55:14.773697 | ||
|
|
||
| """ | ||
| from typing import Sequence, Union | ||
|
|
||
| from alembic import op | ||
| import sqlalchemy as sa | ||
|
|
||
|
|
||
| # revision identifiers, used by Alembic. | ||
| revision: str = '983ad2983d5c' | ||
| down_revision: Union[str, Sequence[str], None] = 'd3dcf6b6fe3f' | ||
| branch_labels: Union[str, Sequence[str], None] = None | ||
| depends_on: Union[str, Sequence[str], None] = None | ||
|
|
||
|
|
||
| def upgrade() -> None: | ||
| """Upgrade schema.""" | ||
| # ### commands auto generated by Alembic - please adjust! ### | ||
| op.add_column('chat_sessions', sa.Column('selected_resource_ids', sa.JSON(), nullable=True)) | ||
| # ### end Alembic commands ### | ||
|
|
||
|
|
||
| def downgrade() -> None: | ||
| """Downgrade schema.""" | ||
| # ### commands auto generated by Alembic - please adjust! ### | ||
| op.drop_column('chat_sessions', 'selected_resource_ids') | ||
| # ### end Alembic commands ### |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,213 @@ | ||
| from fastapi import APIRouter, Depends, HTTPException, status | ||
| from sqlalchemy.ext.asyncio import AsyncSession | ||
| from sqlalchemy import select | ||
| from typing import List | ||
| import uuid | ||
| import json | ||
| import logging | ||
|
|
||
| from ..database import get_db | ||
| from ..models.user import User | ||
| from ..models.resource import Resource | ||
| from ..models.paper import Paper, paper_resources | ||
| from ..models.paper_output import PaperOutput | ||
| from ..schemas.paper import ( | ||
| PaperCreate, PaperOut, PaperUpdate, | ||
| PaperOutputOut, PaperOutputToggle, | ||
| FormatDetectionRequest | ||
| ) | ||
| from .auth import get_current_user | ||
| from ..llm.client import open_router_client | ||
| from ..llm.prompts import DETECT_FORMAT_PROMPT | ||
| from ..config import settings | ||
| from arq import create_pool | ||
| from arq.connections import RedisSettings | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| router = APIRouter(prefix="/papers", tags=["papers"]) | ||
|
|
||
| @router.post("/detect-format") | ||
| async def detect_format( | ||
| data: FormatDetectionRequest, | ||
| db: AsyncSession = Depends(get_db), | ||
| current_user: User = Depends(get_current_user) | ||
| ): | ||
| # Fetch the resource | ||
| result = await db.execute( | ||
| select(Resource).where(Resource.id == data.resource_id, Resource.user_id == current_user.id) | ||
| ) | ||
| resource = result.scalar_one_or_none() | ||
|
|
||
| if not resource: | ||
| raise HTTPException(status_code=404, detail="Resource not found") | ||
|
|
||
| if not resource.extracted_text: | ||
| raise HTTPException(status_code=400, detail="Resource has no extracted text. Please wait for processing.") | ||
|
|
||
| # Call LLM for format detection | ||
| messages = [ | ||
| {"role": "system", "content": "You are a document analyzer."}, | ||
| {"role": "user", "content": f"{DETECT_FORMAT_PROMPT}\n\nPaper Content:\n{resource.extracted_text[:10000]}"} | ||
| ] | ||
|
|
||
| try: | ||
| # Collect stream into full response | ||
| full_response = "" | ||
| async for chunk in open_router_client.stream_chat(messages): | ||
| full_response += chunk | ||
|
|
||
| # Parse JSON from response | ||
| # LLM might return markdown code blocks, strip them if present | ||
| clean_json = full_response.strip() | ||
| if clean_json.startswith("```json"): | ||
| clean_json = clean_json[7:] | ||
| if clean_json.endswith("```"): | ||
| clean_json = clean_json[:-3] | ||
|
|
||
| format_config = json.loads(clean_json) | ||
| return format_config | ||
| except Exception as e: | ||
| raise HTTPException(status_code=500, detail=f"Failed to detect format: {str(e)}") | ||
|
|
||
| @router.post("", response_model=PaperOut) | ||
| async def create_paper( | ||
| data: PaperCreate, | ||
| db: AsyncSession = Depends(get_db), | ||
| current_user: User = Depends(get_current_user) | ||
| ): | ||
| if not data.resources: | ||
| raise HTTPException( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| detail="At least one resource must be selected." | ||
| ) | ||
|
|
||
| # 2. Create Paper record | ||
| new_paper = Paper( | ||
| user_id=current_user.id, | ||
| title=data.title, | ||
| format_config=data.format_config or {}, | ||
| delivery_mode=data.delivery_mode, | ||
| status="pending" | ||
| ) | ||
| db.add(new_paper) | ||
| await db.flush() | ||
|
|
||
| # 3. Link Resources | ||
| for res_link in data.resources: | ||
| # Verify resource exists and belongs to user | ||
| res_result = await db.execute( | ||
| select(Resource).where(Resource.id == res_link.resource_id, Resource.user_id == current_user.id) | ||
| ) | ||
| if not res_result.scalar_one_or_none(): | ||
| raise HTTPException( | ||
| status_code=status.HTTP_400_BAD_REQUEST, | ||
| detail=f"Resource {res_link.resource_id} not found or unauthorized." | ||
| ) | ||
|
|
||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| # Insert into association table | ||
| await db.execute( | ||
| paper_resources.insert().values( | ||
| paper_id=new_paper.id, | ||
| resource_id=res_link.resource_id, | ||
| resource_role=res_link.role | ||
| ) | ||
| ) | ||
|
Comment on lines
+117
to
+136
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reject resources that aren't ready yet. This loop only checks ownership. A direct API caller can attach 🤖 Prompt for AI Agents |
||
|
|
||
| await db.commit() | ||
| await db.refresh(new_paper) | ||
|
|
||
| # 4. Enqueue background task | ||
| redis = await create_pool(RedisSettings.from_dsn(settings.REDIS_URL)) | ||
| try: | ||
| job = await redis.enqueue_job("generate_paper_task", str(new_paper.id)) | ||
| if job is None: | ||
| raise RuntimeError("Failed to enqueue generate_paper_task") | ||
| except Exception as e: | ||
| logger.error(f"Redis enqueue error: {e}") | ||
| new_paper.status = "failed" | ||
| await db.commit() | ||
| raise HTTPException( | ||
| status_code=status.HTTP_503_SERVICE_UNAVAILABLE, | ||
| detail="Paper generation queued failed, please retry." | ||
| ) | ||
| finally: | ||
| await redis.close() | ||
|
Comment on lines
+141
to
+167
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Commit the Line 154 hands ARQ an uncommitted 🧰 Tools🪛 Ruff (0.15.13)[warning] 158-158: Do not catch blind exception: (BLE001) [warning] 162-165: Within an (B904) 🤖 Prompt for AI Agents |
||
|
|
||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| return new_paper | ||
|
|
||
| @router.get("", response_model=List[PaperOut]) | ||
| async def list_papers( | ||
| db: AsyncSession = Depends(get_db), | ||
| current_user: User = Depends(get_current_user) | ||
| ): | ||
| result = await db.execute( | ||
| select(Paper).where(Paper.user_id == current_user.id).order_by(Paper.created_at.desc()) | ||
| ) | ||
| return result.scalars().all() | ||
|
|
||
| @router.get("/{paper_id}", response_model=PaperOut) | ||
| async def get_paper( | ||
| paper_id: uuid.UUID, | ||
| db: AsyncSession = Depends(get_db), | ||
| current_user: User = Depends(get_current_user) | ||
| ): | ||
| result = await db.execute( | ||
| select(Paper).where(Paper.id == paper_id, Paper.user_id == current_user.id) | ||
| ) | ||
| paper = result.scalar_one_or_none() | ||
| if not paper: | ||
| raise HTTPException(status_code=404, detail="Paper not found") | ||
| return paper | ||
|
|
||
| @router.get("/{paper_id}/output", response_model=PaperOutputOut) | ||
| async def get_paper_output( | ||
| paper_id: uuid.UUID, | ||
| db: AsyncSession = Depends(get_db), | ||
| current_user: User = Depends(get_current_user) | ||
| ): | ||
| # Verify paper ownership | ||
| result = await db.execute( | ||
| select(Paper).where(Paper.id == paper_id, Paper.user_id == current_user.id) | ||
| ) | ||
| if not result.scalar_one_or_none(): | ||
| raise HTTPException(status_code=404, detail="Paper not found") | ||
|
|
||
| output_result = await db.execute( | ||
| select(PaperOutput).where(PaperOutput.paper_id == paper_id) | ||
| ) | ||
| output = output_result.scalar_one_or_none() | ||
| if not output: | ||
| raise HTTPException(status_code=404, detail="Paper output not yet generated") | ||
|
|
||
| return output | ||
|
|
||
| @router.patch("/{paper_id}/output", response_model=PaperOutputOut) | ||
| async def toggle_output_settings( | ||
| paper_id: uuid.UUID, | ||
| data: PaperOutputToggle, | ||
| db: AsyncSession = Depends(get_db), | ||
| current_user: User = Depends(get_current_user) | ||
| ): | ||
| # Verify paper ownership | ||
| result = await db.execute( | ||
| select(Paper).where(Paper.id == paper_id, Paper.user_id == current_user.id) | ||
| ) | ||
| if not result.scalar_one_or_none(): | ||
| raise HTTPException(status_code=404, detail="Paper not found") | ||
|
|
||
| output_result = await db.execute( | ||
| select(PaperOutput).where(PaperOutput.paper_id == paper_id) | ||
| ) | ||
| output = output_result.scalar_one_or_none() | ||
| if not output: | ||
| raise HTTPException(status_code=404, detail="Paper output not found") | ||
|
|
||
| if data.include_answers is not None: | ||
| output.include_answers = data.include_answers | ||
| if data.include_explanations is not None: | ||
| output.include_explanations = data.include_explanations | ||
|
|
||
| await db.commit() | ||
| await db.refresh(output) | ||
| return output | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -116,6 +116,9 @@ async def update_session( | |
|
|
||
| if data.title: | ||
| session.title = data.title | ||
| if data.selected_resource_ids is not None: | ||
| # Convert UUIDs to strings for JSON storage | ||
| session.selected_resource_ids = [str(rid) for rid in data.selected_resource_ids] | ||
|
|
||
| await db.commit() | ||
| await db.refresh(session) | ||
|
|
@@ -165,12 +168,22 @@ async def ask_question( | |
| sess_result = await db.execute( | ||
| select(ChatSession).where(ChatSession.id == data.session_id, ChatSession.user_id == current_user.id) | ||
| ) | ||
| if not sess_result.scalar_one_or_none(): | ||
| session = sess_result.scalar_one_or_none() | ||
| if not session: | ||
| raise HTTPException(status_code=404, detail="Chat session not found") | ||
| session_id = data.session_id | ||
|
|
||
| # Update persistent resource selection if changed | ||
| new_resource_ids = [str(rid) for rid in data.resource_ids] | ||
| if session.selected_resource_ids != new_resource_ids: | ||
| session.selected_resource_ids = new_resource_ids | ||
| else: | ||
| # Auto-create session if none provided | ||
| new_sess = ChatSession(user_id=current_user.id, title=data.content[:30] + "...") | ||
| new_sess = ChatSession( | ||
| user_id=current_user.id, | ||
| title=data.content[:30] + "...", | ||
| selected_resource_ids=[str(rid) for rid in data.resource_ids] | ||
|
Comment on lines
+176
to
+185
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Persist only the validated resource IDs.
Proposed fix- new_resource_ids = [str(rid) for rid in data.resource_ids]
+ new_resource_ids = [str(res.id) for res in resources]
if session.selected_resource_ids != new_resource_ids:
session.selected_resource_ids = new_resource_ids
@@
new_sess = ChatSession(
user_id=current_user.id,
title=data.content[:30] + "...",
- selected_resource_ids=[str(rid) for rid in data.resource_ids]
+ selected_resource_ids=[str(res.id) for res in resources]
)🤖 Prompt for AI Agents |
||
| ) | ||
| db.add(new_sess) | ||
| await db.flush() | ||
| session_id = new_sess.id | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid leaking internal exception text in API error responses.
At Line 68, returning
str(e)can expose provider/internal details. Return a generic message to clients and log the detailed error server-side.🧰 Tools
🪛 Ruff (0.15.12)
[warning] 67-67: Do not catch blind exception:
Exception(BLE001)
[warning] 68-68: Within an
exceptclause, raise exceptions withraise ... from errorraise ... from Noneto distinguish them from errors in exception handling(B904)
[warning] 68-68: Use explicit conversion flag
Replace with conversion flag
(RUF010)
🤖 Prompt for AI Agents