🗝
about summary refs log tree commit diff
path: root/zoner.py
diff options
context:
space:
mode:
Diffstat (limited to 'zoner.py')
-rw-r--r--zoner.py226
1 files changed, 226 insertions, 0 deletions
diff --git a/zoner.py b/zoner.py
new file mode 100644
index 0000000..0056396
--- /dev/null
+++ b/zoner.py
@@ -0,0 +1,226 @@
+from pathlib import Path
+import tomllib
+import os
+import requests
+import subprocess
+import sys
+
+
+cfg_dir = Path.home() / ".config/zoner"
+cfg_path = cfg_dir / "zoner.toml"
+if not cfg_path.exists():
+    print(f"config file at {cfg_path} not found")
+    exit(1)
+cfg = tomllib.loads(cfg_path.read_text())
+
+
+class Record:
+    kind: str
+    name: str
+    value: str
+    ttl: int
+    meta: dict
+
+    def __init__(self, kind: str, name: str, value: str, ttl: int, **kwargs):
+        self.kind = kind
+        self.name = name
+        self.content = value
+        self.ttl = ttl
+        self.meta = kwargs
+
+    def __eq__(self, other: object):
+        if isinstance(other, Record):
+            return (
+                self.kind == other.kind
+                and self.name == other.name
+                and self.content == other.content
+                and self.ttl == other.ttl
+            )
+        else:
+            raise TypeError(f"cannot compare {type(self)} with {type(other)}")
+
+    def __str__(self):
+        return f"{self.kind}\t{self.name}\t{self.content}\t$ttl {self.ttl}"
+
+
+def parse(filename: Path, domain: str):
+    text = filename.read_text()
+    while "\t\t" in text:
+        text = text.replace("\t\t", "\t")
+    rules = {"ttl": 3600}
+    for line in text.splitlines():
+        if line.startswith(";") or len(line) == 0:  # comments and empty lines
+            continue
+        if line.startswith("$"):  # rule line
+            update_rules(rules, line)
+            continue
+        # record line
+        line_rules = rules
+        segments = line.split("\t")
+        [name, kind, content, *opts] = segments
+        if len(opts) > 0:  # line-scope rules
+            line_rules = line_rules.copy()
+            for statement in opts:
+                update_rules(line_rules, statement)
+        if name == "@":
+            recname = domain
+        else:
+            recname = f"{name}.{domain}"
+        yield Record(kind, recname, content, line_rules["ttl"], apiname=name)
+
+
+def update_rules(rules: dict, statement: str):
+    [key, value] = statement.split(" ")
+    match key:
+        case "$ttl":
+            rules["ttl"] = int(value)
+        case _:
+            raise ValueError(f"invalid rule {key}")
+
+
+def check_porkbun(resp: requests.Response):
+    if resp.status_code >= 400:
+        raise requests.HTTPError(
+            f"got code {resp.status_code}: {resp.json()}", response=resp
+        )
+
+
+def retrieve(domain: str):
+    resp = requests.post(
+        f"https://porkbun.com/api/json/v3/dns/retrieve/{domain}",
+        json={"apikey": cfg["api_key"], "secretapikey": cfg["secret_key"]},
+    )
+    check_porkbun(resp)
+    for record in resp.json()["records"]:
+        if record["type"] == "NS":
+            continue
+        if record["type"] == "MX":
+            # add prioty to content
+            record["content"] = f"{record['prio']} {record['content']}"
+        if record["ttl"] == None:
+            record["ttl"] = 600
+        else:
+            record["ttl"] = int(record["ttl"])
+        yield Record(
+            record["type"],
+            record["name"],
+            record["content"],
+            record["ttl"],
+            id=record["id"],
+        )
+
+
+def resolve(domain: str):
+    existing = list(retrieve(domain))
+    requested = list(parse(cfg_dir / f"{domain}.zone", domain))
+
+    to_delete = []
+    to_create = []
+    for item in existing:
+        for check in requested:
+            if item == check:
+                break
+        else:
+            to_delete.append(item)
+    for item in requested:
+        for check in existing:
+            if item == check:
+                break
+        else:
+            to_create.append(item)
+    return (to_delete, to_create)
+
+
+def balance(domain: str):
+    to_delete, to_create = resolve(domain)
+
+    if len(to_delete) != 0:
+        print("to delete:")
+        for record in to_delete:
+            print(record)
+        print()
+
+    if len(to_create) != 0:
+        print("to create:")
+        for record in to_create:
+            print(record)
+        print()
+
+    if len(to_delete) == 0 and len(to_create) == 0:
+        print("nothing to do")
+        exit()
+
+    ch = input("continue? [y/N] ")
+    if ch != "y":
+        exit()
+
+    for record in to_delete:
+        print(f"delete {record}")
+        resp = requests.post(
+            f"https://porkbun.com/api/json/v3/dns/delete/{domain}/{record.meta['id']}",
+            json={"apikey": cfg["api_key"], "secretapikey": cfg["secret_key"]},
+        )
+        check_porkbun(resp)
+
+    for record in to_create:
+        print(f"create {record}")
+        content = record.content
+        prio = None
+        if record.kind == "MX":
+            prio, content = content.split(" ")
+        resp = requests.post(
+            f"https://porkbun.com/api/json/v3/dns/create/{domain}",
+            json={
+                "apikey": cfg["api_key"],
+                "secretapikey": cfg["secret_key"],
+                "name": record.meta["apiname"],
+                "type": record.kind,
+                "content": content,
+                "ttl": str(record.ttl),
+                "prio": prio,
+            },
+        )
+        check_porkbun(resp)
+
+    changes = set()
+    for entry in to_create:
+        changes.add(f"create {entry.kind} {entry.name}")
+    for entry in to_delete:
+        changes.add(f"delete {entry.kind} {entry.name}")
+    return list(changes)
+
+
+def git_update(domain: str):
+    subprocess.check_call(["git", "add", f"{domain}.zone"], cwd=cfg_dir)
+    if len(subprocess.check_output(["git", "diff", "--cached"], cwd=cfg_dir)) > 0:
+        subprocess.check_call(["git", "commit", "-m", f"update {domain}"], cwd=cfg_dir)
+        subprocess.check_call(["git", "push"], cwd=cfg_dir)
+
+
+def main():
+    if len(sys.argv) == 1:  # selector
+        names = map(lambda path: path.name[:-5], cfg_dir.glob("*.zone"))
+        ret = subprocess.run(
+            ["fzf"], input="\n".join(names).encode(), stdout=subprocess.PIPE
+        )
+        if ret.returncode != 0:
+            return
+        domain = ret.stdout.decode().strip()
+    else:
+        domain = sys.argv[1]
+    subprocess.run([os.environ["EDITOR"], (cfg_dir / f"{domain}.zone").as_posix()])
+    changes = balance(domain)
+    if (cfg_dir / ".git").exists():
+        git_update(domain)
+    if (cfg_dir / "hook").exists():
+        args = []
+        subprocess.check_call(
+            [
+                (cfg_dir / "hook").as_posix(),
+                *changes,
+            ]
+        )
+
+
+if __name__ == "__main__":
+    main()