# src/paper/broker.py
"""
Paper broker (spot, long-only) with positions tracked PER STRATEGY.

- Tracks cash, equity and realized PnL.
- Positions are keyed by ``position_id`` (e.g. "BTC/USDT::MA_Crossover").
- Each Position stores the REAL traded symbol ("BTC/USDT") so that
  mark-to-market looks up the correct last price.
"""

from __future__ import annotations

import math
from dataclasses import dataclass, asdict
from typing import Optional, Dict, Any, List
from datetime import datetime, timezone


@dataclass
class PaperTrade:
    """Record of one simulated fill produced by ``PaperBroker``."""

    symbol: str  # real ticker, e.g. "BTC/USDT"
    position_id: str  # e.g. "BTC/USDT::MA_Crossover"
    side: str  # "BUY" or "SELL"
    qty: float
    price: float  # executed price (after slippage)
    fee: float
    notional: float  # price * qty
    realized_pnl: float  # (exec_price - avg_entry) * qty on SELL; fees not deducted; 0 on BUY
    timestamp: str  # ISO8601, UTC
    meta: Dict[str, Any]


@dataclass
class Position:
    """Long-only holding for one logical position key."""

    symbol: str  # real ticker used for mark-to-market
    position_id: str  # logical key
    qty: float = 0.0
    avg_entry: float = 0.0

    def is_open(self) -> bool:
        """Return True when the position holds a positive quantity."""
        return self.qty > 0.0


class PaperBroker:
    """Simulated spot, long-only broker with per-strategy positions."""

    # Relative tolerance under which a SELL quantity is treated as the whole
    # position, absorbing floating-point dust (e.g. 0.1 + 0.2 vs 0.3).
    _QTY_REL_TOL: float = 1e-9

    def __init__(
        self,
        initial_cash: float,
        commission_rate: float = 0.001,
        slippage_rate: float = 0.0005,
    ):
        """
        Args:
            initial_cash: starting cash balance; must be > 0.
            commission_rate: fee as a fraction of notional per fill.
            slippage_rate: fractional price penalty applied against the order side.

        Raises:
            ValueError: if ``initial_cash`` <= 0.
        """
        if initial_cash <= 0:
            raise ValueError("initial_cash must be > 0")

        self.initial_cash = float(initial_cash)
        self.cash = float(initial_cash)
        self.commission_rate = float(commission_rate)
        self.slippage_rate = float(slippage_rate)

        # position_id -> Position
        self.positions: Dict[str, Position] = {}
        # symbol -> last price
        self.last_price: Dict[str, float] = {}

        self.realized_pnl: float = 0.0
        self.trades: List[PaperTrade] = []

    # -----------------------------
    # Pricing / MTM
    # -----------------------------
    def update_price(self, symbol: str, price: float) -> None:
        """Record the latest price for ``symbol``; raises ValueError if price <= 0."""
        if price <= 0:
            raise ValueError("price must be > 0")
        self.last_price[symbol] = float(price)

    def get_position(self, position_id: str, symbol: Optional[str] = None) -> Position:
        """
        Return the Position for ``position_id``.

        If it does not exist yet, a new empty Position is created and
        registered; ``symbol`` is then required to initialize it.

        Raises:
            ValueError: if the position must be created and ``symbol`` is None.
        """
        if position_id not in self.positions:
            if symbol is None:
                raise ValueError("symbol is required to create a new position")
            self.positions[position_id] = Position(symbol=symbol, position_id=position_id)
        return self.positions[position_id]

    def get_unrealized_pnl(self, position_id: str) -> float:
        """Unrealized PnL at the last known price; 0.0 if closed, unknown, or unpriced."""
        pos = self.positions.get(position_id)
        if not pos or not pos.is_open():
            return 0.0
        lp = self.last_price.get(pos.symbol)
        if lp is None:
            return 0.0
        return (lp - pos.avg_entry) * pos.qty

    def get_equity(self) -> float:
        """
        Equity = cash + sum(pos.qty * last_price[pos.symbol]).

        Open positions whose symbol has no recorded price are valued at
        their average entry price instead.
        """
        equity = self.cash
        for pos in self.positions.values():
            if not pos.is_open():
                continue
            lp = self.last_price.get(pos.symbol)
            equity += pos.qty * (lp if lp is not None else pos.avg_entry)
        return float(equity)

    # -----------------------------
    # Execution
    # -----------------------------
    def _now_iso(self) -> str:
        """Current UTC time as an ISO8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def _apply_slippage(self, side: str, price: float) -> float:
        """Worsen ``price`` for the given side: BUY pays more, SELL receives less."""
        s = side.upper()
        if s == "BUY":
            return price * (1.0 + self.slippage_rate)
        elif s == "SELL":
            return price * (1.0 - self.slippage_rate)
        raise ValueError(f"Invalid side: {side}")

    def _fee(self, notional: float) -> float:
        """Commission charged on a fill of the given notional."""
        return abs(notional) * self.commission_rate

    def place_market_order(
        self,
        *,
        position_id: str,
        symbol: str,
        side: str,
        qty: float,
        price: float,
        meta: Optional[Dict[str, Any]] = None,
        allow_partial: bool = False,
    ) -> PaperTrade:
        """
        Simulate an immediate market fill and update cash/positions.

        Args:
            position_id: logical position key.
            symbol: real ticker; must match the stored symbol of an existing
                position (a position restored without a symbol adopts this one).
            side: "BUY" or "SELL" (case-insensitive, surrounding spaces ignored).
            qty: requested quantity, > 0. A SELL within a tiny relative
                tolerance of the held quantity closes the position exactly.
            price: reference price, > 0; slippage is applied on top and the raw
                price is recorded as the symbol's last price.
            meta: free-form data attached to the trade record.
            allow_partial: on a BUY that exceeds available cash, fill the largest
                affordable quantity instead of raising.

        Returns:
            The recorded PaperTrade.

        Raises:
            ValueError: on invalid qty/price/side, symbol mismatch, insufficient
                cash (without ``allow_partial``), or selling more than held.
                A rejected order leaves broker state unchanged.
        """
        side = side.upper().strip()
        if qty <= 0:
            raise ValueError("qty must be > 0")
        if price <= 0:
            raise ValueError("price must be > 0")
        meta = meta or {}

        # Also validates side (raises "Invalid side: ...").
        exec_price = self._apply_slippage(side, float(price))

        # Look up without registering: the position is only stored once the
        # order has passed every check, so rejected orders leave no trace.
        pos = self.positions.get(position_id)
        if pos is None:
            pos = Position(symbol=symbol, position_id=position_id)
        elif pos.symbol and pos.symbol != symbol:
            raise ValueError(
                f"Position {position_id!r} holds {pos.symbol!r}, not {symbol!r}"
            )

        realized = 0.0

        if side == "BUY":
            notional = exec_price * float(qty)
            fee = self._fee(notional)
            partial = False
            total_cost = notional + fee
            if total_cost > self.cash:
                if not allow_partial:
                    raise ValueError(
                        f"Not enough cash for BUY. Need {total_cost:.2f}, have {self.cash:.2f}"
                    )
                # Largest qty such that qty * exec_price * (1 + commission) == cash.
                max_qty = max((self.cash / (exec_price * (1.0 + self.commission_rate))), 0.0)
                if max_qty <= 0:
                    raise ValueError("Not enough cash to buy even a minimal quantity.")
                qty = float(max_qty)
                notional = exec_price * qty
                fee = self._fee(notional)
                partial = True

            new_qty = pos.qty + qty
            if pos.qty == 0:
                new_avg = exec_price
            else:
                new_avg = (pos.avg_entry * pos.qty + exec_price * qty) / new_qty
            pos.qty = new_qty
            pos.avg_entry = new_avg
            self.cash -= (notional + fee)
            if partial and self.cash < 0.0:
                # A partial fill spends all cash by construction; any negative
                # remainder is floating-point rounding.
                self.cash = 0.0

        elif side == "SELL":
            if pos.qty > 0 and math.isclose(qty, pos.qty, rel_tol=self._QTY_REL_TOL):
                # Treat as a full exit so no dust position is left open and
                # rounding differences do not cause a spurious rejection.
                qty = pos.qty
            elif qty > pos.qty:
                raise ValueError(f"Cannot SELL more than position qty. Have {pos.qty}, want {qty}")

            notional = exec_price * float(qty)
            fee = self._fee(notional)
            realized = (exec_price - pos.avg_entry) * qty
            self.realized_pnl += realized
            pos.qty -= qty
            if pos.qty <= 0:
                pos.qty = 0.0
                pos.avg_entry = 0.0
            self.cash += (notional - fee)

        else:
            raise ValueError(f"Invalid side: {side}")

        # Commit: register the position and fill in a missing symbol (e.g.
        # after restoring state that lacked one) so mark-to-market works.
        if not pos.symbol:
            pos.symbol = symbol
        self.positions[position_id] = pos

        trade = PaperTrade(
            symbol=symbol,
            position_id=position_id,
            side=side,
            qty=float(qty),
            price=float(exec_price),
            fee=float(fee),
            notional=float(notional),
            realized_pnl=float(realized),
            timestamp=self._now_iso(),
            meta=meta,
        )
        self.trades.append(trade)
        self.update_price(symbol, price)
        return trade

    # -----------------------------
    # Serialization
    # -----------------------------
    def snapshot(self) -> Dict[str, Any]:
        """Return a serializable dict of broker state (trade history as a count only)."""
        return {
            "initial_cash": self.initial_cash,
            "cash": self.cash,
            "commission_rate": self.commission_rate,
            "slippage_rate": self.slippage_rate,
            "realized_pnl": self.realized_pnl,
            "equity": self.get_equity(),
            "positions": {pid: asdict(pos) for pid, pos in self.positions.items()},
            "last_price": dict(self.last_price),
            "trades_count": len(self.trades),
            "updated_at": self._now_iso(),
        }

    def restore(self, state: Dict[str, Any]) -> None:
        """
        Load state produced by ``snapshot``.

        Missing scalar keys keep their current values. Positions and last
        prices are replaced wholesale. Trade history is not persisted, so
        ``self.trades`` is reset to empty.
        """
        self.initial_cash = float(state.get("initial_cash", self.initial_cash))
        self.cash = float(state.get("cash", self.cash))
        self.commission_rate = float(state.get("commission_rate", self.commission_rate))
        self.slippage_rate = float(state.get("slippage_rate", self.slippage_rate))
        self.realized_pnl = float(state.get("realized_pnl", self.realized_pnl))

        self.last_price = {k: float(v) for k, v in (state.get("last_price") or {}).items()}

        self.positions = {}
        for pid, p in (state.get("positions") or {}).items():
            self.positions[pid] = Position(
                symbol=p.get("symbol", ""),  # real ticker; "" if absent (adopted on next order)
                position_id=p.get("position_id", pid),
                qty=float(p.get("qty", 0.0)),
                avg_entry=float(p.get("avg_entry", 0.0)),
            )

        self.trades = []