Filter checkpoints into a Hugging Face model repo
Very specific; not very useful
The `HfApi` interface doesn't give enough git control, as far as I can tell: each commit is immediate, which runs afoul of rate limits.
from pathlib import Path
# Local directory holding the files to publish.
BASE = "/tmp/foobar"
BASE_PATH = Path(BASE)

# Target repository on the Hugging Face Hub.
repo_id = "jimregan/ckpt_test"

# Checkpoint files are matched as "<template><number>.<extension>".
extension = "pth"
ckpt_tpl = "checkpoint_"
best_tpl = "best_model_"

# Name (without extension) under which each best model is stored in the repo.
best_name = "best_model"
from huggingface_hub import HfApi
# Hub client used for every repository operation in this script.
api = HfApi()
# exist_ok=True keeps the script re-runnable: without it create_repo raises
# when the repository already exists (e.g. after an interrupted upload).
api.create_repo(repo_id=repo_id, exist_ok=True)
def get_ckpt_num(ckpt_path: Path, best=False):
    """Extract the number that follows the filename template in a path's stem.

    Args:
        ckpt_path: Path whose stem is expected to be ``<template><number>``.
        best: When true, match against ``best_tpl``; otherwise ``ckpt_tpl``.

    Returns:
        The text after the template converted to ``int``, or ``None`` when
        the stem does not start with the template.

    Raises:
        ValueError: If the text after the template is not a valid integer.
    """
    tpl = best_tpl if best else ckpt_tpl
    stem = ckpt_path.stem
    if not stem.startswith(tpl):
        return None
    # Strip only the leading template. str.replace would also delete any
    # later occurrence of it, silently gluing unrelated digits together.
    return int(stem[len(tpl):])
# Map each checkpoint number to the file it was parsed from.
ckpts_by_numbers = dict(
    (get_ckpt_num(path), path)
    for path in BASE_PATH.glob(f"{ckpt_tpl}*.{extension}")
)
# Checkpoint numbers in ascending order.
ckpts_ordered = sorted(ckpts_by_numbers)
# Best-model numbers in ascending order.
best_ordered = sorted(
    get_ckpt_num(path, True)
    for path in BASE_PATH.glob(f"{best_tpl}*.{extension}")
)
# Bare expression: displays the value when run as a notebook cell.
best_ordered
def find_next_ckpt(best):
    """Return the first entry of ``ckpts_ordered`` that is not below *best*.

    ``ckpts_ordered`` is scanned front to back and the first entry greater
    than or equal to *best* is returned. ``None`` is returned when no entry
    qualifies.
    """
    return next((step for step in ckpts_ordered if step >= best), None)
# Key each best-model number by the checkpoint found for it. When several
# best models resolve to the same checkpoint, the last one in best_ordered
# overwrites the earlier ones.
ckpt_to_best = dict(zip(map(find_next_ckpt, best_ordered), best_ordered))
# Bare expression: displays the value when run as a notebook cell.
ckpt_to_best
from huggingface_hub import CommitOperationAdd
# Upload everything in BASE_PATH that is neither a checkpoint nor a best
# model: regular files go to the repo root, TensorBoard event files to runs/.
ops = []
tfevents = []
for filepath in BASE_PATH.glob("*"):
    # CommitOperationAdd needs a regular file, so sub-directories are skipped
    # rather than allowed to abort the whole upload.
    if not filepath.is_file():
        continue
    # Checkpoint and best-model files are not part of these commits.
    if filepath.stem.startswith((ckpt_tpl, best_name)):
        continue
    if "tfevents" in filepath.name:
        # Collected separately so the "Adding tfevents" commit below actually
        # has content; previously these were appended to `ops` and the
        # tfevents list stayed empty.
        tfevents.append(CommitOperationAdd(path_in_repo=f"runs/{filepath.name}", path_or_fileobj=str(filepath)))
    else:
        ops.append(CommitOperationAdd(path_in_repo=filepath.name, path_or_fileobj=str(filepath)))
# Only commit when there is something to add.
if ops:
    api.create_commit(
        repo_id=repo_id,
        operations=ops,
        commit_message="Initial commit of files other than checkpoints",
    )
if tfevents:
    api.create_commit(
        repo_id=repo_id,
        operations=tfevents,
        commit_message="Adding tfevents",
    )
# Replay the checkpoints in ascending order, one commit each, always writing
# to the same path in the repo. A best model keyed to a checkpoint is
# committed just before that checkpoint.
# NOTE(review): a best model with no checkpoint at or after it is keyed by
# None in ckpt_to_best and is therefore never uploaded by this loop.
pending = []
for ckpt in ckpts_ordered:
    if ckpt in ckpt_to_best:
        best_step = ckpt_to_best[ckpt]
        filepath = BASE_PATH / f"{best_tpl}{best_step}.{extension}"
        pending.append(api.create_commit(
            repo_id=repo_id,
            operations=[CommitOperationAdd(path_in_repo=f"{best_name}.{extension}", path_or_fileobj=str(filepath))],
            commit_message=f"Best model: {best_step}",
            run_as_future=True,
        ))
    # Use the path that was actually globbed instead of rebuilding it from the
    # parsed number, which would miss zero-padded names such as
    # "checkpoint_0100.pth".
    filepath = ckpts_by_numbers[ckpt]
    pending.append(api.create_commit(
        repo_id=repo_id,
        operations=[CommitOperationAdd(path_in_repo=f"checkpoint.{extension}", path_or_fileobj=str(filepath))],
        commit_message=f"Checkpoint: {ckpt}",
        run_as_future=True,
    ))
# run_as_future=True returns a Future per commit. Wait on each one so that a
# failed upload raises here instead of being silently discarded.
for future in pending:
    future.result()