feat(calibration): complete step 1 data inspection with data quality v1
This commit is contained in:
248
src/paper/broker.py
Normal file
248
src/paper/broker.py
Normal file
@@ -0,0 +1,248 @@
|
||||
# src/paper/broker.py
|
||||
"""
|
||||
Paper Broker (Spot, Long-only) con posiciones POR ESTRATEGIA.
|
||||
|
||||
- Mantiene cash, equity, pnl
|
||||
- Posiciones indexadas por position_id (ej: "BTC/USDT::MA_Crossover")
|
||||
- Cada Position guarda el symbol REAL ("BTC/USDT") para MTM correcto
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, asdict
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
@dataclass
|
||||
class PaperTrade:
|
||||
symbol: str # ticker real, e.g. "BTC/USDT"
|
||||
position_id: str # e.g. "BTC/USDT::MA_Crossover"
|
||||
side: str # "BUY" or "SELL"
|
||||
qty: float
|
||||
price: float # executed price (after slippage)
|
||||
fee: float
|
||||
notional: float
|
||||
realized_pnl: float
|
||||
timestamp: str # ISO8601
|
||||
meta: Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Position:
|
||||
symbol: str # ticker real
|
||||
position_id: str # clave lógica
|
||||
qty: float = 0.0
|
||||
avg_entry: float = 0.0
|
||||
|
||||
def is_open(self) -> bool:
|
||||
return self.qty > 0.0
|
||||
|
||||
|
||||
class PaperBroker:
|
||||
def __init__(
|
||||
self,
|
||||
initial_cash: float,
|
||||
commission_rate: float = 0.001,
|
||||
slippage_rate: float = 0.0005,
|
||||
):
|
||||
if initial_cash <= 0:
|
||||
raise ValueError("initial_cash must be > 0")
|
||||
|
||||
self.initial_cash = float(initial_cash)
|
||||
self.cash = float(initial_cash)
|
||||
|
||||
self.commission_rate = float(commission_rate)
|
||||
self.slippage_rate = float(slippage_rate)
|
||||
|
||||
# position_id -> Position
|
||||
self.positions: Dict[str, Position] = {}
|
||||
|
||||
# symbol -> last price
|
||||
self.last_price: Dict[str, float] = {}
|
||||
|
||||
self.realized_pnl: float = 0.0
|
||||
self.trades: List[PaperTrade] = []
|
||||
|
||||
# -----------------------------
|
||||
# Pricing / MTM
|
||||
# -----------------------------
|
||||
def update_price(self, symbol: str, price: float) -> None:
|
||||
if price <= 0:
|
||||
raise ValueError("price must be > 0")
|
||||
self.last_price[symbol] = float(price)
|
||||
|
||||
def get_position(self, position_id: str, symbol: Optional[str] = None) -> Position:
|
||||
"""
|
||||
Devuelve una Position por position_id.
|
||||
Si no existe, crea una nueva (requiere symbol para inicializar correctamente).
|
||||
"""
|
||||
if position_id not in self.positions:
|
||||
if symbol is None:
|
||||
raise ValueError("symbol is required to create a new position")
|
||||
self.positions[position_id] = Position(symbol=symbol, position_id=position_id)
|
||||
return self.positions[position_id]
|
||||
|
||||
def get_unrealized_pnl(self, position_id: str) -> float:
|
||||
pos = self.positions.get(position_id)
|
||||
if not pos or not pos.is_open():
|
||||
return 0.0
|
||||
lp = self.last_price.get(pos.symbol)
|
||||
if lp is None:
|
||||
return 0.0
|
||||
return (lp - pos.avg_entry) * pos.qty
|
||||
|
||||
def get_equity(self) -> float:
|
||||
"""
|
||||
Equity = cash + sum(pos.qty * last_price[pos.symbol])
|
||||
"""
|
||||
equity = self.cash
|
||||
for _, pos in self.positions.items():
|
||||
if pos.is_open():
|
||||
lp = self.last_price.get(pos.symbol)
|
||||
if lp is not None:
|
||||
equity += pos.qty * lp
|
||||
else:
|
||||
equity += pos.qty * pos.avg_entry
|
||||
return float(equity)
|
||||
|
||||
# -----------------------------
|
||||
# Execution
|
||||
# -----------------------------
|
||||
def _now_iso(self) -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
def _apply_slippage(self, side: str, price: float) -> float:
|
||||
s = side.upper()
|
||||
if s == "BUY":
|
||||
return price * (1.0 + self.slippage_rate)
|
||||
elif s == "SELL":
|
||||
return price * (1.0 - self.slippage_rate)
|
||||
raise ValueError(f"Invalid side: {side}")
|
||||
|
||||
def _fee(self, notional: float) -> float:
|
||||
return abs(notional) * self.commission_rate
|
||||
|
||||
def place_market_order(
|
||||
self,
|
||||
*,
|
||||
position_id: str,
|
||||
symbol: str,
|
||||
side: str,
|
||||
qty: float,
|
||||
price: float,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
allow_partial: bool = False,
|
||||
) -> PaperTrade:
|
||||
side = side.upper().strip()
|
||||
if qty <= 0:
|
||||
raise ValueError("qty must be > 0")
|
||||
if price <= 0:
|
||||
raise ValueError("price must be > 0")
|
||||
|
||||
meta = meta or {}
|
||||
|
||||
exec_price = self._apply_slippage(side, float(price))
|
||||
notional = exec_price * float(qty)
|
||||
fee = self._fee(notional)
|
||||
|
||||
pos = self.get_position(position_id, symbol=symbol)
|
||||
realized = 0.0
|
||||
|
||||
if side == "BUY":
|
||||
total_cost = notional + fee
|
||||
if total_cost > self.cash:
|
||||
if not allow_partial:
|
||||
raise ValueError(
|
||||
f"Not enough cash for BUY. Need {total_cost:.2f}, have {self.cash:.2f}"
|
||||
)
|
||||
# qty máxima por cash (aprox)
|
||||
max_qty = max((self.cash / (exec_price * (1.0 + self.commission_rate))), 0.0)
|
||||
if max_qty <= 0:
|
||||
raise ValueError("Not enough cash to buy even a minimal quantity.")
|
||||
qty = float(max_qty)
|
||||
notional = exec_price * qty
|
||||
fee = self._fee(notional)
|
||||
|
||||
new_qty = pos.qty + qty
|
||||
if pos.qty == 0:
|
||||
new_avg = exec_price
|
||||
else:
|
||||
new_avg = (pos.avg_entry * pos.qty + exec_price * qty) / new_qty
|
||||
|
||||
pos.qty = new_qty
|
||||
pos.avg_entry = new_avg
|
||||
|
||||
self.cash -= (notional + fee)
|
||||
|
||||
elif side == "SELL":
|
||||
if qty > pos.qty:
|
||||
raise ValueError(f"Cannot SELL more than position qty. Have {pos.qty}, want {qty}")
|
||||
|
||||
realized = (exec_price - pos.avg_entry) * qty
|
||||
self.realized_pnl += realized
|
||||
|
||||
pos.qty -= qty
|
||||
if pos.qty <= 0:
|
||||
pos.qty = 0.0
|
||||
pos.avg_entry = 0.0
|
||||
|
||||
self.cash += (notional - fee)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid side: {side}")
|
||||
|
||||
trade = PaperTrade(
|
||||
symbol=symbol,
|
||||
position_id=position_id,
|
||||
side=side,
|
||||
qty=float(qty),
|
||||
price=float(exec_price),
|
||||
fee=float(fee),
|
||||
notional=float(notional),
|
||||
realized_pnl=float(realized),
|
||||
timestamp=self._now_iso(),
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
self.trades.append(trade)
|
||||
self.update_price(symbol, price)
|
||||
|
||||
return trade
|
||||
|
||||
# -----------------------------
|
||||
# Serialization
|
||||
# -----------------------------
|
||||
def snapshot(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"initial_cash": self.initial_cash,
|
||||
"cash": self.cash,
|
||||
"commission_rate": self.commission_rate,
|
||||
"slippage_rate": self.slippage_rate,
|
||||
"realized_pnl": self.realized_pnl,
|
||||
"equity": self.get_equity(),
|
||||
"positions": {pid: asdict(pos) for pid, pos in self.positions.items()},
|
||||
"last_price": dict(self.last_price),
|
||||
"trades_count": len(self.trades),
|
||||
"updated_at": self._now_iso(),
|
||||
}
|
||||
|
||||
def restore(self, state: Dict[str, Any]) -> None:
|
||||
self.initial_cash = float(state.get("initial_cash", self.initial_cash))
|
||||
self.cash = float(state.get("cash", self.cash))
|
||||
self.commission_rate = float(state.get("commission_rate", self.commission_rate))
|
||||
self.slippage_rate = float(state.get("slippage_rate", self.slippage_rate))
|
||||
self.realized_pnl = float(state.get("realized_pnl", self.realized_pnl))
|
||||
|
||||
self.last_price = {k: float(v) for k, v in (state.get("last_price") or {}).items()}
|
||||
|
||||
self.positions = {}
|
||||
for pid, p in (state.get("positions") or {}).items():
|
||||
self.positions[pid] = Position(
|
||||
symbol=p.get("symbol", ""), # ticker real
|
||||
position_id=p.get("position_id", pid),
|
||||
qty=float(p.get("qty", 0.0)),
|
||||
avg_entry=float(p.get("avg_entry", 0.0)),
|
||||
)
|
||||
|
||||
self.trades = []
|
||||
Reference in New Issue
Block a user