add cached news, event, and buzz layers
This commit is contained in:
@@ -15,7 +15,9 @@
|
||||
### 🧾 数据来源
|
||||
- **行情**:{行情源}
|
||||
- **K线**:{K线源}
|
||||
- **新闻**:{新闻源}
|
||||
- **事件**:{事件源}
|
||||
- **舆情雷达**:{舆情源}
|
||||
|
||||
---
|
||||
|
||||
@@ -73,8 +75,22 @@
|
||||
- **20日平均收益 / 胜率**:{20日平均收益}% / {20日胜率}%
|
||||
- **回撤代理**:{回撤代理}%
|
||||
|
||||
### 📰 新闻与舆情辅助
|
||||
- **新闻情绪**:{新闻情绪}
|
||||
- **新闻热度**:{新闻热度}
|
||||
- **舆情雷达**:{舆情等级}
|
||||
- **主要样本**:
|
||||
- {新闻1}
|
||||
- {新闻2}
|
||||
- {新闻3}
|
||||
|
||||
### 📌 近期关键事件
|
||||
- {事件1}
|
||||
- {事件2}
|
||||
- {事件3}
|
||||
|
||||
### 💡 分析总结
|
||||
{2-4句话的自然语言总结,至少包含:当前市场场景、操作建议、置信度、主要支撑/风险点。若历史验证样本不足,要明确提醒。}
|
||||
{2-4句话的自然语言总结,至少包含:当前市场场景、操作建议、置信度、主要支撑/风险点。若历史验证样本不足,要明确提醒。可在 1 句话内补充新闻/舆情/事件仅作辅助,不直接决定评级。}
|
||||
|
||||
> ⚠️ 以上分析仅供参考,不构成投资建议。投资有风险,入市需谨慎。
|
||||
|
||||
@@ -121,7 +137,10 @@
|
||||
## 模板使用说明
|
||||
|
||||
- 所有 `{占位符}` 根据脚本返回的 JSON 数据填充。
|
||||
- `{行情源}` / `{K线源}` 使用 `data_sources` 中的实际来源;若事件层尚未接入,`{事件源}` 填 `暂无`。
|
||||
- `{行情源}` / `{K线源}` / `{新闻源}` / `{事件源}` / `{舆情源}` 使用 `data_sources` 中的实际来源;若某层尚未接入,填 `暂无`。
|
||||
- `{新闻情绪}` / `{新闻热度}` / `{新闻1-3}` 来自 `news` 字段;若新闻抓取失败或为空,分别填 `暂无` / `低` / `暂无相关新闻`。
|
||||
- `{舆情等级}` 来自 `buzz.level`;当前为新闻驱动的热度雷达,不是社交媒体实时讨论量。
|
||||
- `{事件1-3}` 来自 `events.items`;美股优先使用 SEC 事件,其他市场先从新闻标题中提取关键事件。
|
||||
- 最终输出必须是标准 Markdown 正文,不要放进 ``` 代码块。
|
||||
- 优先使用短段落、项目符号、卡片式结构;除非用户明确要求,否则尽量不要使用宽表格。
|
||||
- Telegram 等 IM 场景下,优先保证手机端可读性,避免一行承载过多字段。
|
||||
|
||||
@@ -15,20 +15,26 @@ import sys
|
||||
import json
|
||||
import argparse
|
||||
import time
|
||||
import html
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import xml.etree.ElementTree as ET
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
from db import (
|
||||
ANALYSIS_CACHE_TTL_SECONDS,
|
||||
AUX_CACHE_TTL_SECONDS,
|
||||
clear_analysis_cache,
|
||||
get_cached_analysis,
|
||||
get_cached_aux,
|
||||
get_kline_df,
|
||||
get_latest_kline_date,
|
||||
init_db,
|
||||
set_cached_analysis,
|
||||
set_cached_aux,
|
||||
upsert_kline_df,
|
||||
upsert_watchlist_item,
|
||||
)
|
||||
@@ -36,12 +42,15 @@ except ImportError:
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent))
|
||||
from db import (
|
||||
ANALYSIS_CACHE_TTL_SECONDS,
|
||||
AUX_CACHE_TTL_SECONDS,
|
||||
clear_analysis_cache,
|
||||
get_cached_analysis,
|
||||
get_cached_aux,
|
||||
get_kline_df,
|
||||
get_latest_kline_date,
|
||||
init_db,
|
||||
set_cached_analysis,
|
||||
set_cached_aux,
|
||||
upsert_kline_df,
|
||||
upsert_watchlist_item,
|
||||
)
|
||||
@@ -433,40 +442,343 @@ def fetch_us_kline_yahoo(symbol: str, period: str = '6mo') -> pd.DataFrame:
|
||||
'5y': '5y',
|
||||
}
|
||||
url = f"https://query1.finance.yahoo.com/v8/finance/chart/{symbol}?range={range_map.get(period, '6mo')}&interval=1d&includePrePost=false"
|
||||
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
|
||||
with urllib.request.urlopen(req, timeout=20) as response:
|
||||
data = json.loads(response.read().decode('utf-8'))
|
||||
|
||||
result = data.get('chart', {}).get('result', [])
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
result = result[0]
|
||||
timestamps = result.get('timestamp') or []
|
||||
quote = (result.get('indicators', {}).get('quote') or [{}])[0]
|
||||
opens = quote.get('open') or []
|
||||
highs = quote.get('high') or []
|
||||
lows = quote.get('low') or []
|
||||
closes = quote.get('close') or []
|
||||
volumes = quote.get('volume') or []
|
||||
for attempt in range(MAX_RETRIES):
|
||||
try:
|
||||
req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
|
||||
with urllib.request.urlopen(req, timeout=20) as response:
|
||||
data = json.loads(response.read().decode('utf-8'))
|
||||
|
||||
records = []
|
||||
for i, ts in enumerate(timestamps):
|
||||
if i >= len(opens) or opens[i] is None or closes[i] is None or highs[i] is None or lows[i] is None:
|
||||
continue
|
||||
records.append({
|
||||
'Date': datetime.fromtimestamp(ts).strftime('%Y-%m-%d'),
|
||||
'Open': float(opens[i]),
|
||||
'Close': float(closes[i]),
|
||||
'Low': float(lows[i]),
|
||||
'High': float(highs[i]),
|
||||
'Volume': float(volumes[i] or 0),
|
||||
result = data.get('chart', {}).get('result', [])
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
result = result[0]
|
||||
timestamps = result.get('timestamp') or []
|
||||
quote = (result.get('indicators', {}).get('quote') or [{}])[0]
|
||||
opens = quote.get('open') or []
|
||||
highs = quote.get('high') or []
|
||||
lows = quote.get('low') or []
|
||||
closes = quote.get('close') or []
|
||||
volumes = quote.get('volume') or []
|
||||
|
||||
records = []
|
||||
for i, ts in enumerate(timestamps):
|
||||
if i >= len(opens) or opens[i] is None or closes[i] is None or highs[i] is None or lows[i] is None:
|
||||
continue
|
||||
records.append({
|
||||
'Date': datetime.fromtimestamp(ts).strftime('%Y-%m-%d'),
|
||||
'Open': float(opens[i]),
|
||||
'Close': float(closes[i]),
|
||||
'Low': float(lows[i]),
|
||||
'High': float(highs[i]),
|
||||
'Volume': float(volumes[i] or 0),
|
||||
})
|
||||
|
||||
df = pd.DataFrame(records)
|
||||
if not df.empty:
|
||||
df['Date'] = pd.to_datetime(df['Date'])
|
||||
df.set_index('Date', inplace=True)
|
||||
return df
|
||||
except (urllib.error.URLError, json.JSONDecodeError, ValueError) as e:
|
||||
if attempt < MAX_RETRIES - 1:
|
||||
time.sleep(RETRY_BASE_DELAY * (attempt + 1))
|
||||
else:
|
||||
raise Exception(f"获取 Yahoo K线失败: {e}")
|
||||
|
||||
return pd.DataFrame()
|
||||
|
||||
|
||||
# Keyword lists (English + Chinese) used for naive substring-based headline
# sentiment scoring in score_news_title.
POSITIVE_NEWS_KEYWORDS = [
    'beat', 'surge', 'upgrade', 'record', 'growth', 'profit', 'win', 'breakthrough', 'bullish',
    '超预期', '增长', '中标', '利好', '创新高', '回购', '增持', '扭亏', '突破', '上调'
]
NEGATIVE_NEWS_KEYWORDS = [
    'miss', 'lawsuit', 'probe', 'fraud', 'downgrade', 'slump', 'warning', 'decline', 'loss', 'risk',
    '诉讼', '调查', '减持', '亏损', '预警', '下滑', '利空', '处罚', '暴跌', '风险'
]


def score_news_title(title: str) -> int:
    """Return a naive sentiment score for one headline.

    Every positive keyword found as a case-insensitive substring adds 1,
    every negative keyword subtracts 1; a title matching nothing scores 0.
    """
    lowered = title.lower()
    hits_pos = sum(1 for kw in POSITIVE_NEWS_KEYWORDS if kw.lower() in lowered)
    hits_neg = sum(1 for kw in NEGATIVE_NEWS_KEYWORDS if kw.lower() in lowered)
    return hits_pos - hits_neg
|
||||
|
||||
|
||||
def summarize_news_sentiment(items: list[dict]) -> dict:
    """Aggregate per-headline scores into an overall sentiment summary.

    Returns a dict with 'label' (偏正面/偏负面/中性, or 暂无 when empty),
    'heat' (高/中/低 by item count), the summed 'score', and the
    positive/negative/neutral headline counts.
    """
    if not items:
        return {
            'label': '暂无',
            'heat': '低',
            'score': 0,
            'positive': 0,
            'negative': 0,
            'neutral': 0,
        }

    scores = [score_news_title(entry.get('title', '')) for entry in items]
    total = sum(scores)
    positive = sum(1 for s in scores if s > 0)
    negative = sum(1 for s in scores if s < 0)
    neutral = len(scores) - positive - negative

    # Overall label needs a net score of at least +/-2 to leave "neutral".
    if total >= 2:
        label = '偏正面'
    elif total <= -2:
        label = '偏负面'
    else:
        label = '中性'

    # Heat is purely volume-based: 10+ headlines = high, 5+ = medium.
    count = len(items)
    heat = '高' if count >= 10 else ('中' if count >= 5 else '低')

    return {
        'label': label,
        'heat': heat,
        'score': total,
        'positive': positive,
        'negative': negative,
        'neutral': neutral,
    }
|
||||
|
||||
|
||||
def fetch_google_news_rss(code: str, quote: dict, limit: int = 5) -> dict:
    """Fetch recent headlines for a stock from the Google News RSS search feed.

    Builds a language-appropriate query via build_news_query, downloads the
    feed with retries, and returns a dict with keys 'query', 'source',
    'items' (up to *limit* parsed entries) and 'sentiment' (output of
    summarize_news_sentiment). Every failure path returns the same shape
    with an empty item list, so callers never need to handle exceptions.
    """
    query, lang = build_news_query(code, quote)
    # Google News requires matching hl/gl/ceid parameters per language edition.
    if lang == 'zh':
        params = f'q={urllib.parse.quote(query)}&hl=zh-CN&gl=CN&ceid=CN:zh-Hans'
    else:
        params = f'q={urllib.parse.quote(query)}&hl=en-US&gl=US&ceid=US:en'
    url = f'https://news.google.com/rss/search?{params}'

    for attempt in range(MAX_RETRIES):
        try:
            req = urllib.request.Request(url, headers={'User-Agent': 'Mozilla/5.0'})
            with urllib.request.urlopen(req, timeout=20) as response:
                # errors='ignore' drops undecodable bytes instead of failing the fetch.
                xml_text = response.read().decode('utf-8', 'ignore')
            root = ET.fromstring(xml_text)
            channel = root.find('channel')
            rss_items = channel.findall('item') if channel is not None else []
            items = []
            for node in rss_items[:limit]:
                # RSS text nodes may be absent or HTML-escaped; normalize both.
                title = html.unescape((node.findtext('title') or '').strip())
                link = (node.findtext('link') or '').strip()
                pub_date = (node.findtext('pubDate') or '').strip()
                source = html.unescape((node.findtext('source') or '').strip())
                if not title:
                    continue
                items.append({
                    'title': title,
                    'link': link,
                    'published_at': pub_date,
                    'source': source or 'Google News',
                })
            sentiment = summarize_news_sentiment(items)
            return {
                'query': query,
                'source': 'google-news-rss',
                'items': items,
                'sentiment': sentiment,
            }
        except Exception:
            # Broad catch is deliberate: the news layer is best-effort and must
            # never break the main analysis. Retry with linear backoff, then
            # fall back to an empty payload on the final attempt.
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_BASE_DELAY * (attempt + 1))
            else:
                return {
                    'query': query,
                    'source': 'google-news-rss',
                    'items': [],
                    'sentiment': summarize_news_sentiment([]),
                }

    # Unreachable when MAX_RETRIES >= 1 (the loop always returns); kept as a
    # safe default for MAX_RETRIES == 0.
    return {
        'query': query,
        'source': 'google-news-rss',
        'items': [],
        'sentiment': summarize_news_sentiment([]),
    }
|
||||
|
||||
|
||||
# Process-lifetime cache for the SEC ticker->CIK mapping (filled lazily by
# fetch_sec_ticker_map).
SEC_TICKER_MAP_CACHE = None

# Ordered category -> keyword lists for coarse event classification; the
# first category whose keywords match wins.
EVENT_KEYWORDS = {
    '业绩': ['财报', '业绩', '盈利', '营收', '季报', '年报', 'earnings', 'revenue', 'profit'],
    '订单/中标': ['中标', '订单', '签约', '合作', 'contract', 'deal'],
    '资本动作': ['回购', '增持', '减持', '融资', '定增', 'buyback', 'offering'],
    '监管/风险': ['调查', '诉讼', '处罚', '风险', '警告', 'probe', 'lawsuit', 'fraud', 'warning'],
    '产品/技术': ['新品', '发布', '突破', '芯片', 'ai', 'launch', 'breakthrough'],
}


def classify_event_title(title: str) -> str:
    """Map a headline to the first matching event category.

    Matching is a case-insensitive substring test against EVENT_KEYWORDS;
    titles that match nothing fall through to '市场动态'.
    """
    haystack = title.lower()
    matches = (
        category
        for category, keywords in EVENT_KEYWORDS.items()
        if any(word.lower() in haystack for word in keywords)
    )
    return next(matches, '市场动态')
|
||||
|
||||
|
||||
def derive_events_from_news(news: dict, limit: int = 3) -> dict:
    """Build an event layer from news headlines when no dedicated feed exists.

    Takes the first *limit* entries of news['items'] and tags each with a
    coarse category via classify_event_title. Source is marked
    'news-derived' so callers can tell it apart from real event feeds.
    """
    headlines = (news.get('items') or [])[:limit]
    derived = [
        {
            'title': entry.get('title', ''),
            'category': classify_event_title(entry.get('title', '')),
            'source': entry.get('source', 'Google News'),
            'published_at': entry.get('published_at', ''),
            'link': entry.get('link', ''),
        }
        for entry in headlines
    ]
    return {
        'source': 'news-derived',
        'items': derived,
    }
|
||||
|
||||
df = pd.DataFrame(records)
|
||||
if not df.empty:
|
||||
df['Date'] = pd.to_datetime(df['Date'])
|
||||
df.set_index('Date', inplace=True)
|
||||
return df
|
||||
|
||||
def fetch_sec_ticker_map() -> dict:
    """Download the SEC ticker -> zero-padded CIK mapping, memoized per process.

    The first call fetches company_tickers.json from sec.gov; subsequent
    calls return the module-level cache. Raises on network/parse failure
    (callers such as fetch_sec_events wrap this in try/except).
    """
    global SEC_TICKER_MAP_CACHE
    if SEC_TICKER_MAP_CACHE is not None:
        return SEC_TICKER_MAP_CACHE

    # SEC asks for a descriptive User-Agent with contact info on API requests.
    request = urllib.request.Request(
        'https://www.sec.gov/files/company_tickers.json',
        headers={'User-Agent': 'Mozilla/5.0 stockbuddy@example.com'}
    )
    with urllib.request.urlopen(request, timeout=20) as response:
        payload = json.loads(response.read().decode('utf-8'))

    mapping = {}
    for entry in payload.values():
        symbol = (entry.get('ticker') or '').upper()
        cik_raw = str(entry.get('cik_str') or '').strip()
        if symbol and cik_raw:
            # CIKs are zero-padded to 10 digits for the submissions API URL.
            mapping[symbol] = cik_raw.zfill(10)

    SEC_TICKER_MAP_CACHE = mapping
    return mapping
|
||||
|
||||
|
||||
def fetch_sec_events(code: str, quote: dict, limit: int = 3) -> dict:
    """Return up to *limit* recent SEC filings for a US ticker as event items.

    Non-US markets, unknown tickers, and any network/parse failure all yield
    {'source': 'sec', 'items': []} so the caller can fall back to
    news-derived events.
    """
    stock = normalize_stock_code(code)
    if stock['market'] != 'US':
        return {'source': 'sec', 'items': []}

    ticker = stock['code'].upper()
    try:
        mapping = fetch_sec_ticker_map()
        cik = mapping.get(ticker)
        if not cik:
            return {'source': 'sec', 'items': []}
        # SEC asks for a descriptive User-Agent with contact info on API requests.
        req = urllib.request.Request(
            f'https://data.sec.gov/submissions/CIK{cik}.json',
            headers={'User-Agent': 'Mozilla/5.0 stockbuddy@example.com'}
        )
        with urllib.request.urlopen(req, timeout=20) as response:
            payload = json.loads(response.read().decode('utf-8'))
        # 'filings.recent' holds parallel arrays (form, filingDate, ...); the
        # min() below guards against ragged array lengths.
        recent = ((payload.get('filings') or {}).get('recent')) or {}
        forms = recent.get('form') or []
        dates = recent.get('filingDate') or []
        accessions = recent.get('accessionNumber') or []
        primary_docs = recent.get('primaryDocument') or []
        items = []
        for i in range(min(len(forms), len(dates), len(accessions), len(primary_docs), limit)):
            form = forms[i]
            filing_date = dates[i]
            # EDGAR archive paths use the accession number without dashes.
            acc = accessions[i].replace('-', '')
            doc = primary_docs[i]
            link = f'https://www.sec.gov/Archives/edgar/data/{int(cik)}/{acc}/{doc}' if doc else ''
            items.append({
                'title': f'SEC {form} filed on {filing_date}',
                'category': 'SEC申报',
                'source': 'SEC',
                'published_at': filing_date,
                'link': link,
            })
        return {'source': 'sec', 'items': items}
    except Exception:
        # Best-effort layer: swallow errors and let the caller fall back.
        return {'source': 'sec', 'items': []}
|
||||
|
||||
|
||||
def build_buzz_radar(news: dict, events: dict) -> dict:
    """Derive a coarse attention level from news sentiment and event volume.

    Scoring: high news heat +2 (medium +1), a non-neutral sentiment label
    +1, three or more events +1. Score >= 4 maps to '过热', >= 2 to
    '升温', otherwise '正常'. This is a news-driven proxy, not social-media
    volume, hence source 'news-derived-buzz'.
    """
    sentiment = news.get('sentiment') or {}
    heat = sentiment.get('heat', '低')
    event_count = len(events.get('items') or [])

    score = {'高': 2, '中': 1}.get(heat, 0)
    if sentiment.get('label') in ('偏正面', '偏负面'):
        score += 1
    if event_count >= 3:
        score += 1

    if score >= 4:
        level = '过热'
    elif score >= 2:
        level = '升温'
    else:
        level = '正常'

    return {
        'level': level,
        'news_heat': heat,
        'event_count': event_count,
        'sentiment': sentiment.get('label', '暂无'),
        'source': 'news-derived-buzz',
    }
|
||||
|
||||
|
||||
def build_event_layer(code: str, quote: dict, news: dict) -> dict:
    """Pick the best available event source for a stock.

    US tickers try SEC filings first; if that yields nothing (or for any
    other market) events are derived from the news headlines instead.
    """
    market = normalize_stock_code(code)['market']
    if market == 'US':
        sec_layer = fetch_sec_events(code, quote)
        if sec_layer.get('items'):
            return sec_layer
    return derive_events_from_news(news)
|
||||
|
||||
|
||||
def get_or_refresh_aux_layers(code: str, quote: dict, refresh: bool = False) -> dict:
    """Return the news/events/buzz auxiliary layers, using the aux cache.

    Each layer is read from the cache unless *refresh* is True or the cache
    misses, in which case it is rebuilt and written back with the standard
    AUX_CACHE_TTL_SECONDS TTL. Layers are built in dependency order:
    events derive from news, buzz derives from both.
    """
    def _layer(category, builder):
        # One cache-or-build step; empty cached payloads also trigger rebuild.
        cached = None if refresh else get_cached_aux(code, category)
        if cached:
            return cached
        fresh = builder()
        set_cached_aux(code, category, fresh, ttl_seconds=AUX_CACHE_TTL_SECONDS)
        return fresh

    news = _layer('news', lambda: fetch_google_news_rss(code, quote))
    events = _layer('events', lambda: build_event_layer(code, quote, news))
    buzz = _layer('buzz', lambda: build_buzz_radar(news, events))

    return {
        'news': news,
        'events': events,
        'buzz': buzz,
    }
|
||||
|
||||
|
||||
def period_to_days(period: str) -> int:
|
||||
@@ -1040,7 +1352,15 @@ def analyze_stock(code: str, period: str = "6mo", use_cache: bool = True) -> dic
|
||||
if use_cache:
|
||||
cached = get_cached_analysis(full_code, period)
|
||||
if cached:
|
||||
print(f"📦 使用缓存数据 ({full_code}),缓存有效期 {ANALYSIS_CACHE_TTL}s", file=sys.stderr)
|
||||
aux = get_or_refresh_aux_layers(full_code, cached.get('fundamental', {}) | {'name': cached.get('fundamental', {}).get('company_name', '')}, refresh=False)
|
||||
cached['news'] = aux['news']
|
||||
cached['events'] = aux['events']
|
||||
cached['buzz'] = aux['buzz']
|
||||
cached.setdefault('data_sources', {})
|
||||
cached['data_sources']['news'] = aux['news'].get('source', 'google-news-rss')
|
||||
cached['data_sources']['event'] = aux['events'].get('source', 'news-derived')
|
||||
cached['data_sources']['buzz'] = aux['buzz'].get('source', 'news-derived-buzz')
|
||||
print(f"📦 使用缓存数据 ({full_code}),分析缓存有效期 {ANALYSIS_CACHE_TTL}s,辅助层缓存有效期 {AUX_CACHE_TTL_SECONDS}s", file=sys.stderr)
|
||||
return cached
|
||||
|
||||
result = {"code": full_code, "market": stock['market'], "analysis_time": datetime.now().isoformat(), "error": None}
|
||||
@@ -1054,6 +1374,9 @@ def analyze_stock(code: str, period: str = "6mo", use_cache: bool = True) -> dic
|
||||
result["data_sources"] = {
|
||||
"quote": quote.get("quote_source", "tencent"),
|
||||
"kline": None,
|
||||
"news": None,
|
||||
"event": None,
|
||||
"buzz": None,
|
||||
}
|
||||
|
||||
upsert_watchlist_item(
|
||||
@@ -1107,6 +1430,14 @@ def analyze_stock(code: str, period: str = "6mo", use_cache: bool = True) -> dic
|
||||
result["recommendation"] = generate_recommendation(technical, fundamental, current_price, hist, quote)
|
||||
result["signal_validation"] = backtest_current_signal(hist, period)
|
||||
|
||||
aux = get_or_refresh_aux_layers(full_code, quote, refresh=not use_cache)
|
||||
result["news"] = aux["news"]
|
||||
result["events"] = aux["events"]
|
||||
result["buzz"] = aux["buzz"]
|
||||
result["data_sources"]["news"] = aux["news"].get("source", "google-news-rss")
|
||||
result["data_sources"]["event"] = aux["events"].get("source", "news-derived")
|
||||
result["data_sources"]["buzz"] = aux["buzz"].get("source", "news-derived-buzz")
|
||||
|
||||
if result.get("error") is None:
|
||||
set_cached_analysis(full_code, period, result)
|
||||
|
||||
|
||||
101
scripts/db.py
101
scripts/db.py
@@ -7,7 +7,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -18,6 +18,8 @@ DATA_DIR = Path.home() / ".stockbuddy"
|
||||
DB_PATH = DATA_DIR / "stockbuddy.db"
|
||||
ANALYSIS_CACHE_TTL_SECONDS = 600
|
||||
ANALYSIS_CACHE_MAX_ROWS = 1000
|
||||
AUX_CACHE_TTL_SECONDS = 1800
|
||||
AUX_CACHE_MAX_ROWS = 2000
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
|
||||
@@ -123,6 +125,21 @@ def init_db() -> None:
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_analysis_cache_code_period
|
||||
ON analysis_cache (code, period, created_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS aux_cache (
|
||||
cache_key TEXT PRIMARY KEY,
|
||||
code TEXT NOT NULL,
|
||||
category TEXT NOT NULL,
|
||||
result_json TEXT NOT NULL,
|
||||
expires_at TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_aux_cache_expires_at
|
||||
ON aux_cache (expires_at);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_aux_cache_code_category
|
||||
ON aux_cache (code, category, created_at DESC);
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
@@ -165,6 +182,84 @@ def clear_analysis_cache() -> int:
|
||||
return count
|
||||
|
||||
|
||||
def cleanup_aux_cache(conn: sqlite3.Connection | None = None) -> None:
    """Prune aux_cache: delete expired rows, then cap total row count.

    If *conn* is supplied the caller keeps ownership (it is not closed here,
    though this function does commit on it); otherwise a private connection
    is opened and always closed.
    """
    own_conn = conn is None
    conn = conn or get_connection()
    try:
        now = _utc_now_iso()
        # Drop expired rows first so the overflow count reflects live rows only.
        conn.execute("DELETE FROM aux_cache WHERE expires_at <= ?", (now,))
        overflow = conn.execute(
            "SELECT COUNT(*) AS cnt FROM aux_cache"
        ).fetchone()["cnt"] - AUX_CACHE_MAX_ROWS
        if overflow > 0:
            # Evict the oldest rows (by created_at) to get back under the cap.
            conn.execute(
                """
                DELETE FROM aux_cache
                WHERE cache_key IN (
                    SELECT cache_key
                    FROM aux_cache
                    ORDER BY created_at ASC
                    LIMIT ?
                )
                """,
                (overflow,),
            )
        conn.commit()
    finally:
        if own_conn:
            conn.close()
|
||||
|
||||
|
||||
def clear_aux_cache() -> int:
    """Delete every aux_cache row and return how many rows were removed."""
    init_db()
    with get_connection() as conn:
        # Count before deleting so the caller learns how much was purged.
        removed = conn.execute("SELECT COUNT(*) AS cnt FROM aux_cache").fetchone()["cnt"]
        conn.execute("DELETE FROM aux_cache")
        conn.commit()
    return removed
|
||||
|
||||
|
||||
def get_cached_aux(code: str, category: str) -> dict | None:
    """Look up a non-expired aux-layer payload for (code, category).

    Runs cleanup first, then returns the decoded JSON payload with a
    '_from_cache' marker added, or None on a miss/expiry.
    """
    init_db()
    with get_connection() as conn:
        # Opportunistic pruning keeps the table bounded on every read.
        cleanup_aux_cache(conn)
        row = conn.execute(
            """
            SELECT result_json
            FROM aux_cache
            WHERE cache_key = ? AND expires_at > ?
            """,
            (f"{code}:{category}", _utc_now_iso()),
        ).fetchone()
        if row is None:
            return None
        payload = json.loads(row["result_json"])
        payload["_from_cache"] = True
        return payload
|
||||
|
||||
|
||||
def set_cached_aux(code: str, category: str, result: dict, ttl_seconds: int = AUX_CACHE_TTL_SECONDS) -> None:
    """Upsert one aux-layer payload (news/events/buzz) with a TTL.

    The row is keyed by "code:category"; a conflicting key has its payload,
    expiry, and creation time replaced in place.
    """
    init_db()
    key = f"{code}:{category}"
    created = _utc_now_iso()
    deadline = (datetime.utcnow() + timedelta(seconds=ttl_seconds)).replace(microsecond=0).isoformat()
    serialized = json.dumps(result, ensure_ascii=False)
    with get_connection() as conn:
        # Prune before inserting so the row cap is enforced on every write.
        cleanup_aux_cache(conn)
        conn.execute(
            """
            INSERT INTO aux_cache (cache_key, code, category, result_json, expires_at, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
            ON CONFLICT(cache_key) DO UPDATE SET
                result_json = excluded.result_json,
                expires_at = excluded.expires_at,
                created_at = excluded.created_at
            """,
            (key, code, category, serialized, deadline, created),
        )
        conn.commit()
|
||||
|
||||
|
||||
def get_cached_analysis(code: str, period: str) -> dict | None:
|
||||
init_db()
|
||||
with get_connection() as conn:
|
||||
@@ -188,9 +283,7 @@ def get_cached_analysis(code: str, period: str) -> dict | None:
|
||||
def set_cached_analysis(code: str, period: str, result: dict) -> None:
|
||||
init_db()
|
||||
now = _utc_now_iso()
|
||||
expires_at = datetime.utcfromtimestamp(
|
||||
datetime.utcnow().timestamp() + ANALYSIS_CACHE_TTL_SECONDS
|
||||
).replace(microsecond=0).isoformat()
|
||||
expires_at = (datetime.utcnow() + timedelta(seconds=ANALYSIS_CACHE_TTL_SECONDS)).replace(microsecond=0).isoformat()
|
||||
cache_key = f"{code}:{period}"
|
||||
with get_connection() as conn:
|
||||
cleanup_analysis_cache(conn)
|
||||
|
||||
Reference in New Issue
Block a user