Add account-aware execution constraints
This commit is contained in:
190
scripts/db.py
190
scripts/db.py
@@ -48,15 +48,24 @@ def _table_columns(conn: sqlite3.Connection, table: str) -> list[str]:
|
||||
return []
|
||||
|
||||
|
||||
def _ensure_column(conn: sqlite3.Connection, table: str, column: str, ddl: str) -> None:
|
||||
columns = _table_columns(conn, table)
|
||||
if columns and column not in columns:
|
||||
conn.execute(f"ALTER TABLE {table} ADD COLUMN {ddl}")
|
||||
|
||||
|
||||
def _migrate_schema(conn: sqlite3.Connection) -> None:
|
||||
positions_cols = _table_columns(conn, "positions")
|
||||
if positions_cols and "watchlist_id" not in positions_cols:
|
||||
conn.execute("DROP TABLE positions")
|
||||
positions_cols = []
|
||||
|
||||
if positions_cols:
|
||||
_ensure_column(conn, "positions", "account_id", "account_id INTEGER")
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
with get_connection() as conn:
|
||||
_migrate_schema(conn)
|
||||
conn.executescript(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS watchlist (
|
||||
@@ -83,6 +92,31 @@ def init_db() -> None:
|
||||
CREATE INDEX IF NOT EXISTS idx_watchlist_market ON watchlist (market, code);
|
||||
CREATE INDEX IF NOT EXISTS idx_watchlist_is_watched ON watchlist (is_watched, updated_at DESC);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS accounts (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
market TEXT,
|
||||
currency TEXT,
|
||||
cash_balance REAL NOT NULL DEFAULT 0,
|
||||
available_cash REAL,
|
||||
note TEXT DEFAULT '',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_accounts_market_currency
|
||||
ON accounts (market, currency);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS stock_rules (
|
||||
code TEXT PRIMARY KEY,
|
||||
lot_size INTEGER,
|
||||
tick_size REAL,
|
||||
allows_odd_lot INTEGER NOT NULL DEFAULT 0,
|
||||
source TEXT DEFAULT 'manual',
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (code) REFERENCES watchlist(code) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS kline_daily (
|
||||
code TEXT NOT NULL,
|
||||
trade_date TEXT NOT NULL,
|
||||
@@ -102,13 +136,15 @@ def init_db() -> None:
|
||||
|
||||
CREATE TABLE IF NOT EXISTS positions (
|
||||
watchlist_id INTEGER PRIMARY KEY,
|
||||
account_id INTEGER,
|
||||
buy_price REAL NOT NULL,
|
||||
shares INTEGER NOT NULL,
|
||||
buy_date TEXT,
|
||||
note TEXT DEFAULT '',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
FOREIGN KEY (watchlist_id) REFERENCES watchlist(id) ON DELETE CASCADE
|
||||
FOREIGN KEY (watchlist_id) REFERENCES watchlist(id) ON DELETE CASCADE,
|
||||
FOREIGN KEY (account_id) REFERENCES accounts(id) ON DELETE SET NULL
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS analysis_cache (
|
||||
@@ -142,6 +178,7 @@ def init_db() -> None:
|
||||
ON aux_cache (code, category, created_at DESC);
|
||||
"""
|
||||
)
|
||||
_migrate_schema(conn)
|
||||
conn.commit()
|
||||
|
||||
|
||||
@@ -408,6 +445,107 @@ def set_watch_status(code: str, watched: bool) -> dict | None:
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def list_accounts() -> list[dict]:
|
||||
init_db()
|
||||
with get_connection() as conn:
|
||||
rows = conn.execute(
|
||||
"""
|
||||
SELECT id, name, market, currency, cash_balance, available_cash, note, created_at, updated_at
|
||||
FROM accounts
|
||||
ORDER BY market IS NULL, market, currency IS NULL, currency, name
|
||||
"""
|
||||
).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
|
||||
|
||||
def get_account(identifier: int | str) -> dict | None:
|
||||
init_db()
|
||||
with get_connection() as conn:
|
||||
if isinstance(identifier, int) or (isinstance(identifier, str) and identifier.isdigit()):
|
||||
row = conn.execute("SELECT * FROM accounts WHERE id = ?", (int(identifier),)).fetchone()
|
||||
else:
|
||||
row = conn.execute("SELECT * FROM accounts WHERE name = ?", (identifier,)).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def upsert_account(
|
||||
*,
|
||||
name: str,
|
||||
market: str | None = None,
|
||||
currency: str | None = None,
|
||||
cash_balance: float | None = None,
|
||||
available_cash: float | None = None,
|
||||
note: str = "",
|
||||
) -> dict:
|
||||
init_db()
|
||||
now = _utc_now_iso()
|
||||
with get_connection() as conn:
|
||||
existing = conn.execute("SELECT * FROM accounts WHERE name = ?", (name,)).fetchone()
|
||||
created_at = existing["created_at"] if existing else now
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO accounts (name, market, currency, cash_balance, available_cash, note, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(name) DO UPDATE SET
|
||||
market = COALESCE(excluded.market, accounts.market),
|
||||
currency = COALESCE(excluded.currency, accounts.currency),
|
||||
cash_balance = COALESCE(excluded.cash_balance, accounts.cash_balance),
|
||||
available_cash = COALESCE(excluded.available_cash, accounts.available_cash),
|
||||
note = CASE WHEN excluded.note = '' THEN accounts.note ELSE excluded.note END,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(
|
||||
name,
|
||||
market,
|
||||
currency,
|
||||
0 if cash_balance is None else cash_balance,
|
||||
available_cash,
|
||||
note,
|
||||
created_at,
|
||||
now,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
row = conn.execute("SELECT * FROM accounts WHERE name = ?", (name,)).fetchone()
|
||||
return dict(row)
|
||||
|
||||
|
||||
def upsert_stock_rule(
|
||||
*,
|
||||
code: str,
|
||||
lot_size: int | None = None,
|
||||
tick_size: float | None = None,
|
||||
allows_odd_lot: bool = False,
|
||||
source: str = "manual",
|
||||
) -> dict:
|
||||
init_db()
|
||||
now = _utc_now_iso()
|
||||
with get_connection() as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO stock_rules (code, lot_size, tick_size, allows_odd_lot, source, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(code) DO UPDATE SET
|
||||
lot_size = COALESCE(excluded.lot_size, stock_rules.lot_size),
|
||||
tick_size = COALESCE(excluded.tick_size, stock_rules.tick_size),
|
||||
allows_odd_lot = excluded.allows_odd_lot,
|
||||
source = excluded.source,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(code, lot_size, tick_size, int(allows_odd_lot), source, now),
|
||||
)
|
||||
conn.commit()
|
||||
row = conn.execute("SELECT * FROM stock_rules WHERE code = ?", (code,)).fetchone()
|
||||
return dict(row)
|
||||
|
||||
|
||||
def get_stock_rule(code: str) -> dict | None:
|
||||
init_db()
|
||||
with get_connection() as conn:
|
||||
row = conn.execute("SELECT * FROM stock_rules WHERE code = ?", (code,)).fetchone()
|
||||
return dict(row) if row else None
|
||||
|
||||
|
||||
def get_latest_kline_date(code: str, adj_type: str = "qfq") -> str | None:
|
||||
init_db()
|
||||
with get_connection() as conn:
|
||||
@@ -504,6 +642,12 @@ def list_positions() -> list[dict]:
|
||||
"""
|
||||
SELECT
|
||||
p.watchlist_id,
|
||||
p.account_id,
|
||||
a.name AS account_name,
|
||||
a.market AS account_market,
|
||||
a.currency AS account_currency,
|
||||
a.cash_balance AS account_cash_balance,
|
||||
a.available_cash AS account_available_cash,
|
||||
w.code,
|
||||
w.market,
|
||||
w.name,
|
||||
@@ -513,9 +657,15 @@ def list_positions() -> list[dict]:
|
||||
p.buy_date,
|
||||
p.note,
|
||||
p.created_at AS added_at,
|
||||
p.updated_at
|
||||
p.updated_at,
|
||||
sr.lot_size,
|
||||
sr.tick_size,
|
||||
sr.allows_odd_lot,
|
||||
sr.source AS lot_rule_source
|
||||
FROM positions p
|
||||
JOIN watchlist w ON w.id = p.watchlist_id
|
||||
LEFT JOIN accounts a ON a.id = p.account_id
|
||||
LEFT JOIN stock_rules sr ON sr.code = w.code
|
||||
ORDER BY w.code ASC
|
||||
"""
|
||||
).fetchall()
|
||||
@@ -529,6 +679,12 @@ def get_position(code: str) -> dict | None:
|
||||
"""
|
||||
SELECT
|
||||
p.watchlist_id,
|
||||
p.account_id,
|
||||
a.name AS account_name,
|
||||
a.market AS account_market,
|
||||
a.currency AS account_currency,
|
||||
a.cash_balance AS account_cash_balance,
|
||||
a.available_cash AS account_available_cash,
|
||||
w.code,
|
||||
w.market,
|
||||
w.name,
|
||||
@@ -538,9 +694,15 @@ def get_position(code: str) -> dict | None:
|
||||
p.buy_date,
|
||||
p.note,
|
||||
p.created_at AS added_at,
|
||||
p.updated_at
|
||||
p.updated_at,
|
||||
sr.lot_size,
|
||||
sr.tick_size,
|
||||
sr.allows_odd_lot,
|
||||
sr.source AS lot_rule_source
|
||||
FROM positions p
|
||||
JOIN watchlist w ON w.id = p.watchlist_id
|
||||
LEFT JOIN accounts a ON a.id = p.account_id
|
||||
LEFT JOIN stock_rules sr ON sr.code = w.code
|
||||
WHERE w.code = ?
|
||||
""",
|
||||
(code,),
|
||||
@@ -557,6 +719,7 @@ def upsert_position(
|
||||
shares: int,
|
||||
buy_date: str | None,
|
||||
note: str = "",
|
||||
account_id: int | None = None,
|
||||
name: str | None = None,
|
||||
currency: str | None = None,
|
||||
meta: dict | None = None,
|
||||
@@ -574,21 +737,23 @@ def upsert_position(
|
||||
now = _utc_now_iso()
|
||||
with get_connection() as conn:
|
||||
existing = conn.execute(
|
||||
"SELECT created_at FROM positions WHERE watchlist_id = ?", (watch["id"],)
|
||||
"SELECT created_at, account_id FROM positions WHERE watchlist_id = ?", (watch["id"],)
|
||||
).fetchone()
|
||||
created_at = existing["created_at"] if existing else now
|
||||
account_id_value = account_id if account_id is not None else (existing["account_id"] if existing else None)
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO positions (watchlist_id, buy_price, shares, buy_date, note, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
INSERT INTO positions (watchlist_id, account_id, buy_price, shares, buy_date, note, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(watchlist_id) DO UPDATE SET
|
||||
account_id = excluded.account_id,
|
||||
buy_price = excluded.buy_price,
|
||||
shares = excluded.shares,
|
||||
buy_date = excluded.buy_date,
|
||||
note = excluded.note,
|
||||
updated_at = excluded.updated_at
|
||||
""",
|
||||
(watch["id"], buy_price, shares, buy_date, note, created_at, now),
|
||||
(watch["id"], account_id_value, buy_price, shares, buy_date, note, created_at, now),
|
||||
)
|
||||
conn.commit()
|
||||
return get_position(code)
|
||||
@@ -605,7 +770,13 @@ def remove_position(code: str) -> bool:
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
def update_position_fields(code: str, price: float | None = None, shares: int | None = None, note: str | None = None) -> dict | None:
|
||||
def update_position_fields(
|
||||
code: str,
|
||||
price: float | None = None,
|
||||
shares: int | None = None,
|
||||
note: str | None = None,
|
||||
account_id: int | None = None,
|
||||
) -> dict | None:
|
||||
current = get_position(code)
|
||||
if not current:
|
||||
return None
|
||||
@@ -618,6 +789,7 @@ def update_position_fields(code: str, price: float | None = None, shares: int |
|
||||
shares=shares if shares is not None else current["shares"],
|
||||
buy_date=current.get("buy_date"),
|
||||
note=note if note is not None else current.get("note", ""),
|
||||
account_id=account_id if account_id is not None else current.get("account_id"),
|
||||
name=watch.get("name"),
|
||||
currency=watch.get("currency"),
|
||||
meta=json.loads(watch["meta_json"]) if watch.get("meta_json") else None,
|
||||
|
||||
Reference in New Issue
Block a user