MyDRMTools/MSTAR/FindKeyBox/find_keybox.py
SuperUserek a5788d6c31 Added Find Keybox python script and test_emmc.bin file
NOTE: TEST_EMMC.bin is generated test data! DO NOT USE IT AS A REAL KEYBOX.
2026-02-23 18:10:12 +00:00

200 lines
5.4 KiB
Python

#!/usr/bin/env python3
import hashlib
import mmap
import os
import sys
from typing import Iterable, List, Optional, Tuple
# --- Search patterns ---------------------------------------------------------
# Marker searched for in the input. It decodes to the 32 ASCII bytes
# "MSTAR_SECURE_STORE_FILE_MAGIC_ID"; extracted blocks start at this marker.
PREFIX_HEX = "4D 53 54 41 52 5F 53 45 43 55 52 45 5F 53 54 4F 52 45 5F 46 49 4C 45 5F 4D 41 47 49 43 5F 49 44"
# Candidate byte sequences expected after the extracted block (hex text,
# whitespace and "0x" markers allowed).
POSTFIX_HEX_LIST = ["00 00"]

# --- Extraction / output -----------------------------------------------------
# Number of bytes copied out, starting at the prefix offset (0xE4).
EXTRACT_LEN = 228
# Output name for the first unique hit; later hits become keybox_1.bin, ...
OUT_BASE = "keybox.bin"
# Whether to print a hex dump of every saved block, and its row width.
PRINT_HEXVIEW = True
HEXVIEW_WIDTH = 16
# Stop after this many saved blocks; 0 disables the limit.
MAX_HITS = 0
# When True, a postfix only counts if it begins exactly at the block end.
REQUIRE_POSTFIX_AT_EXTRACT_END = True

# --- Plausibility filters applied to each candidate block --------------------
# Reject blocks in which more than this fraction of bytes are 0x00.
MAX_ZERO_FRACTION = 0.25
# Reject blocks containing a run of 0x00 bytes longer than this.
MAX_ZERO_RUN = 16
# The "tail" is the block from this offset to its end (0x90 = 144) ...
TAIL_START = 0x90
# ... and must have at least this fraction of non-zero bytes.
MIN_TAIL_NONZERO_RATIO = 0.70
def clean_hex(s: str) -> bytes:
    """Decode a loosely formatted hex string into raw bytes.

    Every "0x"/"0X" marker and all whitespace are removed first, so both
    ``"0x4D 0x53"`` and ``"4D53"`` decode to ``b"MS"``. An input that is
    empty after cleaning yields ``b""``.

    Raises:
        ValueError: if an odd number of hex digits remains, or (from
            ``bytes.fromhex``) if a non-hex character is present.
    """
    digits = "".join(s.replace("0x", "").replace("0X", "").split())
    if len(digits) % 2:
        raise ValueError("Hex string must have an even number of hex digits.")
    return bytes.fromhex(digits) if digits else b""
def iter_all(mm: mmap.mmap, needle: bytes, start: int = 0) -> Iterable[int]:
    """Lazily yield every offset at which *needle* occurs in *mm*.

    The scan begins at *start*. After a hit at offset ``p`` it resumes at
    ``p + 1``, so overlapping occurrences are all reported, in increasing
    order.
    """
    hit = mm.find(needle, start)
    while hit != -1:
        yield hit
        hit = mm.find(needle, hit + 1)
def safe_out_name(base: str, idx: int) -> str:
    """Return the output filename for the *idx*-th saved hit (1-based).

    Hit 1 uses *base* unchanged. Hit N > 1 gets ``_{N-1}`` inserted before
    the extension, e.g. keybox.bin, keybox_1.bin, keybox_2.bin, ...
    """
    stem, suffix = os.path.splitext(base)
    return base if idx == 1 else f"{stem}_{idx - 1}{suffix}"
def is_printable_ascii(b: int) -> bool:
    """Return True if byte value *b* is printable ASCII (0x20 space through 0x7E '~')."""
    return not (b < 32 or b > 126)
def hexview(data: bytes, base_offset: int = 0, width: int = 16) -> str:
    """Render *data* as a classic hex dump, one row per *width* bytes.

    Each row has three parts:
    - the offset (``base_offset`` plus the row start) as 8 uppercase hex digits;
    - the bytes as uppercase hex pairs, padded so short last rows stay aligned;
    - an ASCII column between ``|`` bars, with non-printable bytes shown as ``.``.

    Rows are joined with newlines; empty *data* gives an empty string.
    """
    pad = width * 3 - 1  # "XX " per byte, minus the trailing space
    rows = []
    for start in range(0, len(data), width):
        row = data[start:start + width]
        hex_cols = " ".join(f"{byte:02X}" for byte in row)
        text_cols = "".join(
            chr(byte) if is_printable_ascii(byte) else "." for byte in row
        )
        rows.append(f"{base_offset + start:08X} {hex_cols.ljust(pad)} |{text_cols}|")
    return "\n".join(rows)
def best_postfix_window(postfixes: List[bytes]) -> int:
    """Return the length of the longest non-empty postfix, or 0 if there is none."""
    longest = 0
    for candidate in postfixes:
        if candidate and len(candidate) > longest:
            longest = len(candidate)
    return longest
def find_postfix_after_extract(
    mm: mmap.mmap,
    postfixes: List[bytes],
    prefix_pos: int,
    file_size: int,
    extract_len: Optional[int] = None,
) -> Tuple[int, bytes]:
    """Locate a postfix marker right after an extracted block.

    The block spans ``extract_len`` bytes from ``prefix_pos``. The search
    starts at the block end and covers a window as wide as the longest
    non-empty postfix.

    A postfix that begins exactly at the block end always wins. Previously
    the first postfix in list order was returned, so a shorter postfix
    matching later in the window could shadow an exact match of a longer
    one. If no postfix starts exactly at the block end, the first postfix in
    list order found anywhere in the window is returned.

    Args:
        mm: Buffer to search; anything with ``find(sub, start, end)`` works.
        postfixes: Candidate postfix byte strings; empty entries are ignored.
        prefix_pos: Offset where the prefix (and the extracted block) starts.
        file_size: Size of the data in bytes; no match may extend past it.
        extract_len: Block length; defaults to the module-level EXTRACT_LEN.

    Returns:
        ``(offset, postfix)`` of the chosen match, or ``(-1, b"")`` if none.
    """
    if extract_len is None:
        extract_len = EXTRACT_LEN
    start = prefix_pos + extract_len
    if start >= file_size:
        return -1, b""
    candidates = [pf for pf in postfixes if pf]
    if not candidates:
        return -1, b""
    window = max(len(pf) for pf in candidates)
    end = min(file_size, start + window)
    fallback = (-1, b"")
    for pf in candidates:
        if start + len(pf) > file_size:
            continue  # cannot fit before end of data
        pos = mm.find(pf, start, end)
        if pos == start:
            # find() returns the earliest hit, so this is an exact match.
            return pos, pf
        if pos != -1 and fallback[0] == -1:
            fallback = (pos, pf)
    return fallback
def max_zero_run(b: bytes) -> int:
    """Return the length of the longest run of consecutive 0x00 bytes in *b* (0 if none)."""
    longest = run = 0
    for byte in b:
        run = run + 1 if byte == 0 else 0
        longest = max(longest, run)
    return longest
def zero_fraction(b: bytes) -> float:
    """Return the fraction of bytes in *b* that are 0x00.

    An empty input counts as entirely zero and yields 1.0.
    """
    size = len(b)
    if size == 0:
        return 1.0
    zeros = len([x for x in b if x == 0])
    return zeros / size
def nonzero_ratio(b: bytes) -> float:
    """Return the fraction of bytes in *b* that are not 0x00.

    An empty input yields 0.0.
    """
    size = len(b)
    if size == 0:
        return 0.0
    nonzeros = len([x for x in b if x != 0])
    return nonzeros / size
def passes_filters(block: bytes) -> bool:
    """Heuristic plausibility check for an extracted block.

    The block is rejected when any of these holds:
    - more than MAX_ZERO_FRACTION of its bytes are 0x00;
    - it contains a 0x00 run longer than MAX_ZERO_RUN;
    - its tail (bytes from TAIL_START onward) has a non-zero ratio below
      MIN_TAIL_NONZERO_RATIO.

    The tail test is skipped when the block is not longer than TAIL_START.
    """
    if zero_fraction(block) > MAX_ZERO_FRACTION or max_zero_run(block) > MAX_ZERO_RUN:
        return False
    if TAIL_START >= len(block):
        return True
    tail = block[TAIL_START:]
    return not tail or not (nonzero_ratio(tail) < MIN_TAIL_NONZERO_RATIO)
def _iter_candidates(
    mm: mmap.mmap, prefix: bytes, postfixes: List[bytes], file_size: int
) -> Iterable[Tuple[int, int, bytes]]:
    """Yield ``(prefix_pos, postfix_pos, block)`` for each acceptable prefix hit.

    A hit is acceptable when all of these hold:
    - the EXTRACT_LEN-byte block starting at the prefix fits in the file;
    - a postfix is found after it (exactly at its end when
      REQUIRE_POSTFIX_AT_EXTRACT_END is set);
    - the block passes ``passes_filters``.

    Duplicate blocks are not removed here.
    """
    for ppos in iter_all(mm, prefix, 0):
        if ppos + EXTRACT_LEN > file_size:
            # iter_all yields increasing offsets, so no later hit can fit either.
            break
        qpos, _ = find_postfix_after_extract(mm, postfixes, ppos, file_size)
        if qpos == -1:
            continue
        if REQUIRE_POSTFIX_AT_EXTRACT_END and qpos != (ppos + EXTRACT_LEN):
            continue
        block = mm[ppos:ppos + EXTRACT_LEN]
        if passes_filters(block):
            yield ppos, qpos, block


def main() -> int:
    """Command-line entry point: ``find_keybox.py <file>``.

    Scans the file for PREFIX_HEX and extracts the EXTRACT_LEN bytes starting
    at each hit, subject to the postfix and plausibility filters. Each unique
    block (deduplicated by SHA-256) is written to OUT_BASE, keybox_1.bin, ...
    in the current directory. With PRINT_HEXVIEW set, each saved block is
    also hex-dumped.

    Returns:
        0 if at least one block was saved.
        1 if nothing was found, including an empty input file.
        2 on usage, configuration or I/O errors.
    """
    if len(sys.argv) != 2:
        print("Usage: python find_keybox.py <file>", file=sys.stderr)
        return 2
    path = sys.argv[1]
    if not os.path.isfile(path):
        print("ERROR: file not found", file=sys.stderr)
        return 2
    try:
        prefix = clean_hex(PREFIX_HEX)
        postfixes = [clean_hex(x) for x in POSTFIX_HEX_LIST]
    except ValueError as e:
        print(f"ERROR: {e}", file=sys.stderr)
        return 2
    postfixes = [p for p in postfixes if p]
    if not prefix or not postfixes or EXTRACT_LEN <= 0:
        # Previously this exited with status 2 without saying why.
        print(
            "ERROR: PREFIX_HEX and POSTFIX_HEX_LIST must be non-empty "
            "and EXTRACT_LEN must be positive",
            file=sys.stderr,
        )
        return 2
    print(f"EXTRACT_LEN={EXTRACT_LEN} bytes (0x{EXTRACT_LEN:X})")
    file_size = os.path.getsize(path)
    if file_size == 0:
        # mmap raises ValueError for zero-length files; there is nothing to find.
        print("No matches: input file is empty", file=sys.stderr)
        return 1
    saved = 0
    seen = set()  # SHA-256 digests of blocks already written
    try:
        with open(path, "rb") as f:
            with mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
                for ppos, qpos, block in _iter_candidates(mm, prefix, postfixes, file_size):
                    h = hashlib.sha256(block).digest()
                    if h in seen:
                        continue
                    seen.add(h)
                    saved += 1
                    out_name = safe_out_name(OUT_BASE, saved)
                    with open(out_name, "wb") as out:
                        out.write(block)
                    print(f"[hit {saved}] wrote {len(block)} bytes -> {out_name} | prefix@0x{ppos:X} postfix@0x{qpos:X}")
                    if PRINT_HEXVIEW:
                        print(hexview(block, base_offset=ppos, width=HEXVIEW_WIDTH))
                        print()
                    if MAX_HITS and saved >= MAX_HITS:
                        break
    except OSError as e:
        # Unreadable input or unwritable output: report instead of a traceback.
        print(f"ERROR: {e}", file=sys.stderr)
        return 2
    return 0 if saved else 1
if __name__ == "__main__":
    # Use main()'s return value (0/1/2) as the process exit status.
    sys.exit(main())