268 lines
10 KiB
Python
268 lines
10 KiB
Python
"""
|
|
股票数据服务
|
|
负责:数据获取、缓存、持仓管理
|
|
"""
|
|
|
|
import json
import logging
import os
import sys
from datetime import datetime, timedelta

import pandas as pd
import yfinance as yf
from sqlalchemy.orm import Session

# 添加父目录到路径 — put the parent of this file's directory on sys.path so the
# local `database` / `models` modules below can be imported.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from database import Position, StockData, AnalysisResult, TradeLog
from models import PositionCreate
|
|
|
|
class StockService:
    """Stock data service: position management, market data and analysis storage.

    Every database operation goes through the SQLAlchemy session given to the
    constructor; market data is fetched from yfinance.
    """

    # Known Hong Kong stock names -> Yahoo Finance symbols used by
    # search_ticker(). Defined once instead of being rebuilt on every call.
    _HK_TICKER_MAPPING = {
        '中芯国际': '0981.HK',
        '平安好医生': '1833.HK',
        '叮当健康': '9886.HK',
        '中原建业': '9982.HK',
        '阅文集团': '0772.HK',
        '泰升集团': '0687.HK',
    }

    # Position attributes copied verbatim from a PositionCreate payload; shared
    # by create_position() and update_position() so the two cannot drift apart.
    _POSITION_FIELDS = ('stock_name', 'ticker', 'shares', 'cost_price', 'strategy', 'notes')

    # (StockData attribute, DataFrame column) pairs written by update_stock_data().
    _STOCK_FIELD_MAP = (
        ('open_price', 'Open'),
        ('high_price', 'High'),
        ('low_price', 'Low'),
        ('close_price', 'Close'),
        ('volume', 'Volume'),
        ('ma5', 'MA5'),
        ('ma20', 'MA20'),
        ('ma60', 'MA60'),
        ('rsi', 'RSI'),
        ('atr', 'ATR'),
    )

    _logger = logging.getLogger(__name__)

    def __init__(self, db: Session):
        """Bind the service to ``db`` and make sure the cache directory exists.

        Args:
            db: SQLAlchemy session used for every query and commit.
        """
        self.db = db
        # data/cache located two directories above this file's directory.
        self.cache_dir = os.path.join(os.path.dirname(__file__), '..', '..', 'data', 'cache')
        os.makedirs(self.cache_dir, exist_ok=True)

    def _commit(self):
        """Commit the session; on failure roll back and re-raise."""
        try:
            self.db.commit()
        except Exception:
            # A failed flush leaves the session unusable until it is rolled
            # back; do that so later calls on this service still work.
            self.db.rollback()
            raise

    # ═════════════════════════════════════════════════════════════════
    # Position management
    # ═════════════════════════════════════════════════════════════════

    def get_all_positions(self):
        """Return all positions after refreshing their market values.

        The latest quote is fetched once per distinct ticker, then
        ``current_price``, ``market_value``, ``pnl`` and ``pnl_percent`` are
        recomputed and committed. When a quote cannot be obtained (lookup
        error or no usable price) the failure is logged and that position's
        stored values are left untouched.

        Returns:
            List of ``Position`` rows.
        """
        positions = self.db.query(Position).all()
        quote_cache = {}  # ticker -> quote dict from get_realtime_quote()

        for pos in positions:
            try:
                if pos.ticker not in quote_cache:
                    quote_cache[pos.ticker] = self.get_realtime_quote(pos.ticker)
                price = quote_cache[pos.ticker]['price']
                if not price:
                    # get_realtime_quote() falls back to 0 when it finds no
                    # price; don't mark the position down to zero on that.
                    self._logger.warning("No price available for %s; keeping stored values", pos.ticker)
                    continue
                cost_basis = pos.shares * pos.cost_price
                market_value = pos.shares * price
                pnl = market_value - cost_basis
                # Zero shares or zero cost price would otherwise divide by zero.
                pnl_percent = (pnl / cost_basis) * 100 if cost_basis else 0.0
            except Exception:
                # Best effort: one bad ticker must not stop the others, but the
                # failure should not be silent either.
                self._logger.warning("Failed to refresh position %s", pos.ticker, exc_info=True)
                continue

            # Assign only after every value was computed, so a row is never
            # left half-updated.
            pos.current_price = price
            pos.market_value = market_value
            pos.pnl = pnl
            pos.pnl_percent = pnl_percent

        self._commit()
        return positions

    def create_position(self, position: PositionCreate):
        """Insert a new position built from ``position`` and return the refreshed row."""
        db_position = Position(**{field: getattr(position, field) for field in self._POSITION_FIELDS})
        self.db.add(db_position)
        self._commit()
        self.db.refresh(db_position)
        return db_position

    def _get_position_or_raise(self, position_id: int):
        """Return the position with ``position_id``; raise ValueError if absent."""
        db_position = self.db.query(Position).filter(Position.id == position_id).first()
        if not db_position:
            raise ValueError("持仓不存在")
        return db_position

    def update_position(self, position_id: int, position: PositionCreate):
        """Overwrite every editable field of position ``position_id``.

        Raises:
            ValueError: if no position with that id exists.
        """
        db_position = self._get_position_or_raise(position_id)
        for field in self._POSITION_FIELDS:
            setattr(db_position, field, getattr(position, field))
        self._commit()
        self.db.refresh(db_position)
        return db_position

    def delete_position(self, position_id: int):
        """Delete position ``position_id``.

        Raises:
            ValueError: if no position with that id exists.
        """
        db_position = self._get_position_or_raise(position_id)
        self.db.delete(db_position)
        self._commit()

    # ═════════════════════════════════════════════════════════════════
    # Data fetching
    # ═════════════════════════════════════════════════════════════════

    def update_stock_data(self, ticker: str, period: str = "2y"):
        """Download daily bars for ``ticker``, add indicators and upsert them.

        Args:
            ticker: Yahoo Finance symbol, e.g. ``'0981.HK'``.
            period: yfinance look-back period string (default ``"2y"``).

        Returns:
            The indicator-enriched DataFrame that was written. Rows with any
            missing value are dropped, including at least the first 59 bars
            (MA60 is undefined there).

        Raises:
            ValueError: if yfinance returns no data for ``ticker``.
        """
        df = self._download_history(ticker, period)
        df = self._add_indicators(df).dropna()
        self._upsert_stock_rows(ticker, df)
        return df

    @staticmethod
    def _download_history(ticker: str, period: str):
        """Fetch adjusted daily OHLCV bars from yfinance with flat column names."""
        df = yf.download(ticker, period=period, auto_adjust=True, progress=False)
        if df is None or df.empty:
            raise ValueError(f"无法获取{ticker}的数据")
        # yfinance may return (field, ticker) MultiIndex columns even for one
        # ticker; keep only the field level ('Open', 'Close', ...).
        if isinstance(df.columns, pd.MultiIndex):
            df.columns = df.columns.droplevel(1)
        return df

    @staticmethod
    def _add_indicators(df):
        """Add MA5/MA20/MA60, RSI and ATR columns to ``df`` (in place) and return it."""
        close = df['Close']
        df['MA5'] = close.rolling(5).mean()
        df['MA20'] = close.rolling(20).mean()
        df['MA60'] = close.rolling(60).mean()

        # RSI from exponentially smoothed (alpha = 1/14) average gain and loss.
        delta = close.diff()
        gain = delta.clip(lower=0).ewm(alpha=1/14).mean()
        loss = (-delta.clip(upper=0)).ewm(alpha=1/14).mean()
        df['RSI'] = 100 - (100 / (1 + gain / loss))

        # ATR: 14-bar simple mean of the true range.
        prev_close = close.shift(1)
        true_range = pd.concat(
            [
                df['High'] - df['Low'],
                (df['High'] - prev_close).abs(),
                (df['Low'] - prev_close).abs(),
            ],
            axis=1,
        ).max(axis=1)
        df['ATR'] = true_range.rolling(14).mean()
        return df

    @staticmethod
    def _date_key(value):
        """Normalise a stored ``StockData.date`` value to a ``'YYYY-MM-DD'`` string."""
        # Rows are written with '%Y-%m-%d' strings; also tolerate the column
        # being loaded back as a date/datetime object.
        return value.strftime('%Y-%m-%d') if hasattr(value, 'strftime') else str(value)

    def _upsert_stock_rows(self, ticker: str, df):
        """Insert or update one ``StockData`` row per row of ``df``, then commit."""
        # Load this ticker's stored rows with one query instead of one per bar.
        existing_by_date = {}
        for record in self.db.query(StockData).filter(StockData.ticker == ticker).all():
            existing_by_date.setdefault(self._date_key(record.date), record)

        for date, row in df.iterrows():
            date_str = date.strftime('%Y-%m-%d')
            values = {attr: float(row[column]) for attr, column in self._STOCK_FIELD_MAP}

            existing = existing_by_date.get(date_str)
            if existing is not None:
                for attr, value in values.items():
                    setattr(existing, attr, value)
            else:
                new_data = StockData(ticker=ticker, date=date_str, **values)
                self.db.add(new_data)
                existing_by_date[date_str] = new_data

        self._commit()

    def get_stock_data(self, ticker: str, days: int = 60):
        """Return the latest ``days`` stored bars for ``ticker``, oldest first.

        Returns:
            DataFrame with columns date, open, high, low, close, volume, ma5,
            ma20, ma60, rsi, atr, or ``None`` when nothing is stored.
        """
        data = self.db.query(StockData).filter(
            StockData.ticker == ticker
        ).order_by(StockData.date.desc()).limit(days).all()

        if not data:
            return None

        df = pd.DataFrame([{
            'date': d.date,
            'open': d.open_price,
            'high': d.high_price,
            'low': d.low_price,
            'close': d.close_price,
            'volume': d.volume,
            'ma5': d.ma5,
            'ma20': d.ma20,
            'ma60': d.ma60,
            'rsi': d.rsi,
            'atr': d.atr,
        } for d in data])

        # Rows were fetched newest-first so LIMIT keeps the latest bars; flip
        # to chronological order (the index labels are reversed with them).
        return df.iloc[::-1]

    def get_realtime_quote(self, ticker: str):
        """Fetch the latest price and day change for ``ticker`` from yfinance.

        The price is the close of the most recent daily bar and the previous
        close is the bar before it (falling back to ``info['previousClose']``).
        If the history call fails or returns nothing, ``info`` fields are used.

        Returns:
            dict with keys ticker, name, price, change, change_percent,
            volume and updated_at (ISO timestamp). ``price`` is 0 when no
            price could be determined; change values are 0 when no previous
            close is known.

        Raises:
            Whatever ``yf.Ticker(ticker).info`` raises (not guarded here).
        """
        stock = yf.Ticker(ticker)
        info = stock.info

        def info_number(*keys):
            # First non-empty value among the given info keys, else 0.0.
            for key in keys:
                value = info.get(key)
                if value:
                    return float(value)
            return 0.0

        closes = None
        try:
            # A multi-day window is required: period="1d" yields a single bar,
            # so the previous close equalled the price and change was always 0.
            hist = stock.history(period="5d")
            if not hist.empty:
                closes = hist['Close'].dropna()
        except Exception:
            self._logger.warning("Price history unavailable for %s; using info fields", ticker, exc_info=True)

        if closes is not None and not closes.empty:
            current_price = float(closes.iloc[-1])
            if len(closes) > 1:
                prev_close = float(closes.iloc[-2])
            else:
                prev_close = info_number('previousClose', 'regularMarketPreviousClose')
        else:
            current_price = info_number('currentPrice', 'regularMarketPrice')
            prev_close = info_number('previousClose', 'regularMarketPreviousClose')

        # Without a previous close the change is unknown; report 0 instead of
        # treating the whole price as the day's change.
        change = current_price - prev_close if prev_close else 0.0
        change_percent = (change / prev_close) * 100 if prev_close else 0.0

        return {
            'ticker': ticker,
            'name': info.get('longName') or ticker,
            'price': current_price,
            'change': change,
            'change_percent': change_percent,
            'volume': info.get('volume', 0),
            'updated_at': datetime.now().isoformat(),
        }

    def search_ticker(self, stock_name: str):
        """Resolve a stock name or HK ticker to a Yahoo Finance symbol (simplified).

        Names in ``_HK_TICKER_MAPPING`` map to their symbol. Input that already
        looks like an HK ticker (``'xxxx.HK'``, any case, surrounding
        whitespace ignored) is returned upper-cased.

        Raises:
            ValueError: if the input is neither a known name nor an HK ticker.
        """
        name = stock_name.strip()
        if name in self._HK_TICKER_MAPPING:
            return self._HK_TICKER_MAPPING[name]

        # Already in ticker form: return it in Yahoo's canonical upper case.
        if name.upper().endswith('.HK'):
            return name.upper()

        raise ValueError(f"无法识别股票: {stock_name}")

    # ═════════════════════════════════════════════════════════════════
    # Analysis results
    # ═════════════════════════════════════════════════════════════════

    def save_analysis_result(self, ticker: str, result: dict):
        """Persist ``result`` for ``ticker`` under today's local date.

        ``action``, ``score`` and ``confidence`` are read from
        ``result['signal']`` (defaults HOLD / 0 / LOW when missing or None);
        the whole dict is stored in ``full_data``.
        """
        signal = result.get('signal') or {}
        analysis = AnalysisResult(
            ticker=ticker,
            date=datetime.now().strftime('%Y-%m-%d'),
            action=signal.get('action', 'HOLD'),
            score=signal.get('score', 0),
            confidence=signal.get('confidence', 'LOW'),
            full_data=result,
        )
        self.db.add(analysis)
        self._commit()

    def get_latest_analysis(self, ticker: str):
        """Return ``full_data`` of the newest analysis for ``ticker`` (by created_at), or None."""
        result = self.db.query(AnalysisResult).filter(
            AnalysisResult.ticker == ticker
        ).order_by(AnalysisResult.created_at.desc()).first()

        return result.full_data if result else None
|