#!/usr/bin/env python3
"""
fix-bundle-router-bindings.py   (BUNDLEFIX-005, part 2 / option A)

Adds `bindings: *internal-bindings` to every mysql-router application block in the
Caracal bundle, so the router subordinates bind to the metal space -- matching the
live `juju bind <router> metal` fix already applied to the running model.

Why: without an explicit binding the routers default to the empty 'alpha' space,
which resolves to the container's PROVIDER address. The cluster then grants
mysqlrouteruser@<provider-addr>, but the router's actual TCP connection to the
metal-only cluster egresses the metal interface -> grant host != source ->
"Access denied 1045" -> mysqlrouter never bootstraps. Binding to metal makes the
advertised address == the connection source.

Safe by construction:
  - pure line edits (NO YAML round-trip; preserves anchors, comments, formatting)
  - timestamped .bak
  - prints a unified diff
  - idempotent (skips any router that already carries a bindings line)
  - yaml.safe_load verification of the result, asserting every mysql-router app
    resolves to bindings {'': 'metal'} via the *internal-bindings anchor
  - aborts unless it finds the expected mysql-router blocks and they verify

Usage:
  python3 fix-bundle-router-bindings.py [path/to/bundle.yaml]   (default ./bundle.yaml)
"""
import sys, os, difflib, datetime

# Bundle path used when no path argument is given on the command line.
DEFAULT: str = "bundle.yaml"


def _app_block_has_bindings(lines, idx, key_indent):
    """Return True if the app block containing ``lines[idx]`` already has a
    ``bindings:`` key at indentation ``key_indent``.

    Scans outward from ``idx`` (backwards, then forwards), skipping blank and
    comment-only lines. In each direction it stops at the first line indented
    less than ``key_indent``: the app's own name line above, or the next app /
    top-level section below. Lines indented deeper than ``key_indent`` (nested
    option values, block scalars) are ignored, so a nested ``bindings:`` does
    not count as the app's bindings key.
    """
    def _scan(indices):
        for j in indices:
            text = lines[j]
            stripped = text.strip()
            if not stripped or stripped.startswith("#"):
                continue
            indent = len(text) - len(text.lstrip())
            if indent < key_indent:
                return False  # left the app block
            if indent == key_indent and stripped.startswith("bindings:"):
                return True
        return False

    return _scan(range(idx - 1, -1, -1)) or _scan(range(idx + 1, len(lines)))


def transform(lines):
    """Insert `<indent>bindings: *internal-bindings` after the channel line of
    every `charm: mysql-router` app block that doesn't already have a bindings line.

    A block is recognised only when its `channel:` line immediately follows the
    `charm: mysql-router` line. The existing-bindings check covers the whole app
    block (any key at the channel line's indentation), not just the line after
    `channel:`. A router whose `bindings:` key sits elsewhere in its block is
    therefore skipped, rather than being given a duplicate key.

    Args:
        lines: file content as a list of lines without line terminators.

    Returns:
        (out, found, inserted, skipped): the new list of lines, the number of
        mysql-router blocks recognised, how many received a bindings line, and
        how many were left unchanged because they already had one.
    """
    out = []
    prev_is_mr_charm = False
    found = inserted = skipped = 0
    for idx, line in enumerate(lines):
        out.append(line)
        stripped = line.strip()
        if prev_is_mr_charm and stripped.startswith("channel:"):
            found += 1
            indent = line[: len(line) - len(line.lstrip())]
            if _app_block_has_bindings(lines, idx, len(indent)):
                skipped += 1
            else:
                # Same indentation as `channel:` => a sibling key of the app.
                out.append(f"{indent}bindings: *internal-bindings")
                inserted += 1
        prev_is_mr_charm = (stripped == "charm: mysql-router")
    return out, found, inserted, skipped


def main():
    """Patch the bundle named on the command line (default ./bundle.yaml).

    Prints a unified diff and, when PyYAML is available, verifies that every
    mysql-router app resolves to bindings {'': 'metal'} before writing. The
    original is saved to a timestamped .bak first.

    Returns the process exit code: 0 = written or already idempotent,
    2 = file not found, 3 = no mysql-router blocks recognised,
    4 = semantic verification failed, 5 = the YAML verification raised an error.
    """
    path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT
    if not os.path.isfile(path):
        print(f"[ABORT] not found: {path}")
        return 2

    # newline="" disables universal-newline translation, so `original` holds the
    # file's exact line endings. The .bak below is then a faithful copy, and a
    # CRLF bundle is not silently rewritten as LF.
    with open(path, "r", encoding="utf-8", newline="") as f:
        original = f.read()
    lines = original.splitlines()
    nl = "\r\n" if "\r\n" in original else "\n"

    out, found, inserted, skipped = transform(lines)
    new = nl.join(out) + (nl if original.endswith("\n") else "")

    if found == 0:
        print("[ABORT] no `charm: mysql-router` + `channel:` blocks found - unexpected structure.")
        return 3
    if inserted == 0:  # every recognised block was skipped (skipped == found)
        print(f"[OK/IDEMPOTENT] all {found} mysql-router apps already bound; no change.")
        return 0

    print("=== unified diff ===")
    diff = "\n".join(difflib.unified_diff(
        original.splitlines(), new.splitlines(),
        fromfile=f"{path} (orig)", tofile=f"{path} (new)", lineterm=""))
    print(diff or "(no diff)")
    print(f"=== mysql-router blocks: {found} | inserted: {inserted} | already-bound: {skipped} ===")

    # semantic verification (anchors resolve under safe_load)
    try:
        import yaml
        doc = yaml.safe_load(new)
        apps = (doc or {}).get("applications", {}) or {}
        mr = {k: v for k, v in apps.items()
              if isinstance(v, dict) and v.get("charm") == "mysql-router"}
        # The line scan edited `found` blocks. If safe_load sees fewer
        # mysql-router apps under `applications` (e.g. a legacy `services:` key),
        # some edits are unverified. An empty `mr` would otherwise make the
        # `bad` check below pass vacuously.
        if len(mr) < found:
            print(f"[ABORT] verification failed; line scan found {found} mysql-router blocks "
                  f"but yaml.safe_load sees only {len(mr)} under 'applications'.")
            return 4
        bad = {k: v.get("bindings") for k, v in mr.items() if v.get("bindings") != {"": "metal"}}
        if bad:
            print(f"[ABORT] verification failed; not bound to {{'': 'metal'}}: {bad}")
            return 4
        print(f"[VERIFY] yaml.safe_load OK; all {len(mr)} mysql-router apps -> bindings {{'': 'metal'}}.")
    except ImportError:
        print("[WARN] PyYAML missing; skipped semantic verify (re-verify on jumphost after pull).")
    except Exception as e:
        print(f"[ABORT] yaml verification error: {e}")
        return 5

    ts = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    bak = f"{path}.bak-{ts}"
    # newline="" writes the strings verbatim (no "\n" -> os.linesep translation).
    with open(bak, "w", encoding="utf-8", newline="") as f:
        f.write(original)
    with open(path, "w", encoding="utf-8", newline="") as f:
        f.write(new)
    print(f"[WROTE] {path}  (backup: {bak})")
    return 0


if __name__ == "__main__":
    # main() returns the process exit code; SystemExit hands it to the interpreter.
    raise SystemExit(main())
