Add account-aware execution constraints

This commit is contained in:
root
2026-04-02 01:26:10 +08:00
parent fbe69ab70e
commit 5d2482c9ba
3 changed files with 514 additions and 141 deletions

View File

@@ -48,15 +48,24 @@ def _table_columns(conn: sqlite3.Connection, table: str) -> list[str]:
return []
def _ensure_column(conn: sqlite3.Connection, table: str, column: str, ddl: str) -> None:
columns = _table_columns(conn, table)
if columns and column not in columns:
conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")
def _migrate_schema(conn: sqlite3.Connection) -> None:
positions_cols = _table_columns(conn, "positions")
if positions_cols and "watchlist_id" not in positions_cols:
conn.execute("DROP TABLE positions")
positions_cols = []
if positions_cols:
_ensure_column(conn, "positions", "account_id", "account_id INTEGER")
def init_db() -> None:
with get_connection() as conn:
_migrate_schema(conn)
conn.executescript(
"""
CREATE TABLE IF NOT EXISTS watchlist (
@@ -83,6 +92,31 @@ def init_db() -> None:
CREATE INDEX IF NOT EXISTS idx_watchlist_market ON watchlist (market, code);
CREATE INDEX IF NOT EXISTS idx_watchlist_is_watched ON watchlist (is_watched, updated_at DESC);
CREATE TABLE IF NOT EXISTS accounts (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL UNIQUE,
market TEXT,
currency TEXT,
cash_balance REAL NOT NULL DEFAULT 0,
available_cash REAL,
note TEXT DEFAULT '',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_accounts_market_currency
ON accounts (market, currency);
CREATE TABLE IF NOT EXISTS stock_rules (
code TEXT PRIMARY KEY,
lot_size INTEGER,
tick_size REAL,
allows_odd_lot INTEGER NOT NULL DEFAULT 0,
source TEXT DEFAULT 'manual',
updated_at TEXT NOT NULL,
FOREIGN KEY (code) REFERENCES watchlist(code) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS kline_daily (
code TEXT NOT NULL,
trade_date TEXT NOT NULL,
@@ -102,13 +136,15 @@ def init_db() -> None:
CREATE TABLE IF NOT EXISTS positions (
watchlist_id INTEGER PRIMARY KEY,
account_id INTEGER,
buy_price REAL NOT NULL,
shares INTEGER NOT NULL,
buy_date TEXT,
note TEXT DEFAULT '',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
FOREIGN KEY (watchlist_id) REFERENCES watchlist(id) ON DELETE CASCADE
FOREIGN KEY (watchlist_id) REFERENCES watchlist(id) ON DELETE CASCADE,
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE SET NULL
);
CREATE TABLE IF NOT EXISTS analysis_cache (
@@ -142,6 +178,7 @@ def init_db() -> None:
ON aux_cache (code, category, created_at DESC);
"""
)
_migrate_schema(conn)
conn.commit()
@@ -408,6 +445,107 @@ def set_watch_status(code: str, watched: bool) -> dict | None:
return dict(row) if row else None
def list_accounts() -> list[dict]:
init_db()
with get_connection() as conn:
rows = conn.execute(
"""
SELECT id, name, market, currency, cash_balance, available_cash, note, created_at, updated_at
FROM accounts
ORDER BY market IS NULL, market, currency IS NULL, currency, name
"""
).fetchall()
return [dict(row) for row in rows]
def get_account(identifier: int | str) -> dict | None:
init_db()
with get_connection() as conn:
if isinstance(identifier, int) or (isinstance(identifier, str) and identifier.isdigit()):
row = conn.execute("SELECT * FROM accounts WHERE id = ?", (int(identifier),)).fetchone()
else:
row = conn.execute("SELECT * FROM accounts WHERE name = ?", (identifier,)).fetchone()
return dict(row) if row else None
def upsert_account(
*,
name: str,
market: str | None = None,
currency: str | None = None,
cash_balance: float | None = None,
available_cash: float | None = None,
note: str = "",
) -> dict:
init_db()
now = _utc_now_iso()
with get_connection() as conn:
existing = conn.execute("SELECT * FROM accounts WHERE name = ?", (name,)).fetchone()
created_at = existing["created_at"] if existing else now
conn.execute(
"""
INSERT INTO accounts (name, market, currency, cash_balance, available_cash, note, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(name) DO UPDATE SET
market = COALESCE(excluded.market, accounts.market),
currency = COALESCE(excluded.currency, accounts.currency),
cash_balance = COALESCE(excluded.cash_balance, accounts.cash_balance),
available_cash = COALESCE(excluded.available_cash, accounts.available_cash),
note = CASE WHEN excluded.note = '' THEN accounts.note ELSE excluded.note END,
updated_at = excluded.updated_at
""",
(
name,
market,
currency,
0 if cash_balance is None else cash_balance,
available_cash,
note,
created_at,
now,
),
)
conn.commit()
row = conn.execute("SELECT * FROM accounts WHERE name = ?", (name,)).fetchone()
return dict(row)
def upsert_stock_rule(
*,
code: str,
lot_size: int | None = None,
tick_size: float | None = None,
allows_odd_lot: bool = False,
source: str = "manual",
) -> dict:
init_db()
now = _utc_now_iso()
with get_connection() as conn:
conn.execute(
"""
INSERT INTO stock_rules (code, lot_size, tick_size, allows_odd_lot, source, updated_at)
VALUES (?, ?, ?, ?, ?, ?)
ON CONFLICT(code) DO UPDATE SET
lot_size = COALESCE(excluded.lot_size, stock_rules.lot_size),
tick_size = COALESCE(excluded.tick_size, stock_rules.tick_size),
allows_odd_lot = excluded.allows_odd_lot,
source = excluded.source,
updated_at = excluded.updated_at
""",
(code, lot_size, tick_size, int(allows_odd_lot), source, now),
)
conn.commit()
row = conn.execute("SELECT * FROM stock_rules WHERE code = ?", (code,)).fetchone()
return dict(row)
def get_stock_rule(code: str) -> dict | None:
init_db()
with get_connection() as conn:
row = conn.execute("SELECT * FROM stock_rules WHERE code = ?", (code,)).fetchone()
return dict(row) if row else None
def get_latest_kline_date(code: str, adj_type: str = "qfq") -> str | None:
init_db()
with get_connection() as conn:
@@ -504,6 +642,12 @@ def list_positions() -> list[dict]:
"""
SELECT
p.watchlist_id,
p.account_id,
a.name AS account_name,
a.market AS account_market,
a.currency AS account_currency,
a.cash_balance AS account_cash_balance,
a.available_cash AS account_available_cash,
w.code,
w.market,
w.name,
@@ -513,9 +657,15 @@ def list_positions() -> list[dict]:
p.buy_date,
p.note,
p.created_at AS added_at,
p.updated_at
p.updated_at,
sr.lot_size,
sr.tick_size,
sr.allows_odd_lot,
sr.source AS lot_rule_source
FROM positions p
JOIN watchlist w ON w.id = p.watchlist_id
LEFT JOIN accounts a ON a.id = p.account_id
LEFT JOIN stock_rules sr ON sr.code = w.code
ORDER BY w.code ASC
"""
).fetchall()
@@ -529,6 +679,12 @@ def get_position(code: str) -> dict | None:
"""
SELECT
p.watchlist_id,
p.account_id,
a.name AS account_name,
a.market AS account_market,
a.currency AS account_currency,
a.cash_balance AS account_cash_balance,
a.available_cash AS account_available_cash,
w.code,
w.market,
w.name,
@@ -538,9 +694,15 @@ def get_position(code: str) -> dict | None:
p.buy_date,
p.note,
p.created_at AS added_at,
p.updated_at
p.updated_at,
sr.lot_size,
sr.tick_size,
sr.allows_odd_lot,
sr.source AS lot_rule_source
FROM positions p
JOIN watchlist w ON w.id = p.watchlist_id
LEFT JOIN accounts a ON a.id = p.account_id
LEFT JOIN stock_rules sr ON sr.code = w.code
WHERE w.code = ?
""",
(code,),
@@ -557,6 +719,7 @@ def upsert_position(
shares: int,
buy_date: str | None,
note: str = "",
account_id: int | None = None,
name: str | None = None,
currency: str | None = None,
meta: dict | None = None,
@@ -574,21 +737,23 @@ def upsert_position(
now = _utc_now_iso()
with get_connection() as conn:
existing = conn.execute(
"SELECT created_at FROM positions WHERE watchlist_id = ?", (watch["id"],)
"SELECT created_at, account_id FROM positions WHERE watchlist_id = ?", (watch["id"],)
).fetchone()
created_at = existing["created_at"] if existing else now
account_id_value = account_id if account_id is not None else (existing["account_id"] if existing else None)
conn.execute(
"""
INSERT INTO positions (watchlist_id, buy_price, shares, buy_date, note, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?)
INSERT INTO positions (watchlist_id, account_id, buy_price, shares, buy_date, note, created_at, updated_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(watchlist_id) DO UPDATE SET
account_id = excluded.account_id,
buy_price = excluded.buy_price,
shares = excluded.shares,
buy_date = excluded.buy_date,
note = excluded.note,
updated_at = excluded.updated_at
""",
(watch["id"], buy_price, shares, buy_date, note, created_at, now),
(watch["id"], account_id_value, buy_price, shares, buy_date, note, created_at, now),
)
conn.commit()
return get_position(code)
@@ -605,7 +770,13 @@ def remove_position(code: str) -> bool:
return cur.rowcount > 0
def update_position_fields(code: str, price: float | None = None, shares: int | None = None, note: str | None = None) -> dict | None:
def update_position_fields(
code: str,
price: float | None = None,
shares: int | None = None,
note: str | None = None,
account_id: int | None = None,
) -> dict | None:
current = get_position(code)
if not current:
return None
@@ -618,6 +789,7 @@ def update_position_fields(code: str, price: float | None = None, shares: int |
shares=shares if shares is not None else current["shares"],
buy_date=current.get("buy_date"),
note=note if note is not None else current.get("note", ""),
account_id=account_id if account_id is not None else current.get("account_id"),
name=watch.get("name"),
currency=watch.get("currency"),
meta=json.loads(watch["meta_json"]) if watch.get("meta_json") else None,