fix: use SITE_URL and X-Forwarded headers for RSS URLs in reverse proxy

- Add get_base_url() helper function to detect HTTPS reverse proxy
- Prioritize SITE_URL env var over request.base_url
- Support X-Forwarded-Proto and X-Forwarded-Host headers
- Fixes RSS URL showing http:// instead of https:// behind reverse proxy

Fixes #6

Made-with: Cursor
This commit is contained in:
tmwgsicp 2026-03-24 23:38:44 +08:00
parent f9968a4e0d
commit ad62e8b8bb
2 changed files with 39 additions and 5 deletions

View File

@ -11,6 +11,7 @@ RSS 订阅路由
import csv import csv
import io import io
import os
import time import time
import logging import logging
from datetime import datetime, timezone from datetime import datetime, timezone
@ -28,6 +29,23 @@ from utils.image_proxy import proxy_image_url
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_base_url(request: Request) -> str:
"""
获取服务的基础 URL优先使用环境变量 SITE_URL
支持反向代理检测 X-Forwarded-Proto X-Forwarded-Host
"""
# 优先使用配置的 SITE_URL
site_url = os.getenv("SITE_URL", "").strip()
if site_url:
return site_url.rstrip("/")
# 检测反向代理头部
proto = request.headers.get("X-Forwarded-Proto", "http")
host = request.headers.get("X-Forwarded-Host") or request.headers.get("Host", "localhost:5000")
return f"{proto}://{host}"
router = APIRouter() router = APIRouter()
@ -118,7 +136,7 @@ async def get_subscriptions(request: Request):
返回每个订阅的基本信息缓存文章数和 RSS 地址 返回每个订阅的基本信息缓存文章数和 RSS 地址
""" """
subs = rss_store.list_subscriptions() subs = rss_store.list_subscriptions()
base_url = str(request.base_url).rstrip("/") base_url = get_base_url(request)
items = [] items = []
for s in subs: for s in subs:
@ -195,7 +213,7 @@ async def get_aggregated_rss_feed(
articles = rss_store.get_all_articles(limit=limit) if subs else [] articles = rss_store.get_all_articles(limit=limit) if subs else []
base_url = str(request.base_url).rstrip("/") base_url = get_base_url(request)
xml = _build_aggregated_rss_xml(articles, nickname_map, base_url) xml = _build_aggregated_rss_xml(articles, nickname_map, base_url)
return Response( return Response(
content=xml, content=xml,
@ -218,7 +236,7 @@ async def export_subscriptions(
- **opml**: 标准 OPML 格式可直接导入 RSS 阅读器 - **opml**: 标准 OPML 格式可直接导入 RSS 阅读器
""" """
subs = rss_store.list_subscriptions() subs = rss_store.list_subscriptions()
base_url = str(request.base_url).rstrip("/") base_url = get_base_url(request)
if format == "opml": if format == "opml":
return _build_opml_response(subs, base_url) return _build_opml_response(subs, base_url)
@ -448,7 +466,7 @@ async def get_rss_feed(fakeid: str, request: Request,
raise HTTPException(status_code=404, detail="未找到该订阅,请先添加订阅") raise HTTPException(status_code=404, detail="未找到该订阅,请先添加订阅")
articles = rss_store.get_articles(fakeid, limit=limit) articles = rss_store.get_articles(fakeid, limit=limit)
base_url = str(request.base_url).rstrip("/") base_url = get_base_url(request)
xml = _build_rss_xml(fakeid, sub, articles, base_url) xml = _build_rss_xml(fakeid, sub, articles, base_url)
return Response( return Response(

View File

@ -8,6 +8,7 @@
搜索路由 - FastAPI版本 搜索路由 - FastAPI版本
""" """
import os
from fastapi import APIRouter, Query, Request from fastapi import APIRouter, Query, Request
from pydantic import BaseModel from pydantic import BaseModel
from typing import Optional, List from typing import Optional, List
@ -18,6 +19,21 @@ from utils.image_proxy import proxy_image_url
router = APIRouter() router = APIRouter()
def get_base_url(request: Request) -> str:
"""
获取服务的基础 URL优先使用环境变量 SITE_URL
支持反向代理检测 X-Forwarded-Proto X-Forwarded-Host
"""
site_url = os.getenv("SITE_URL", "").strip()
if site_url:
return site_url.rstrip("/")
proto = request.headers.get("X-Forwarded-Proto", "http")
host = request.headers.get("X-Forwarded-Host") or request.headers.get("Host", "localhost:5000")
return f"{proto}://{host}"
class Account(BaseModel): class Account(BaseModel):
"""公众号模型""" """公众号模型"""
id: str id: str
@ -80,7 +96,7 @@ async def search_accounts(query: str = Query(..., description="公众号名称
accounts = result.get("list", []) accounts = result.get("list", [])
# 获取 base_url 用于图片代理 # 获取 base_url 用于图片代理
base_url = str(request.base_url).rstrip("/") if request else "" base_url = get_base_url(request) if request else ""
# 格式化返回数据 # 格式化返回数据
formatted_accounts = [] formatted_accounts = []