diff --git a/kischdle/llmux/llmux/auth.py b/kischdle/llmux/llmux/auth.py new file mode 100644 index 0000000..ac6a0e5 --- /dev/null +++ b/kischdle/llmux/llmux/auth.py @@ -0,0 +1,19 @@ +from fastapi import HTTPException, Request + +from llmux.config import ApiKey + + +def create_api_key_dependency(api_keys: list[ApiKey]): + key_to_name = {k.key: k.name for k in api_keys} + + async def require_api_key(request: Request) -> str: + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + raise HTTPException(status_code=401, detail="Missing or malformed Authorization header") + token = auth[7:] + name = key_to_name.get(token) + if name is None: + raise HTTPException(status_code=401, detail="Invalid API key") + return name + + return require_api_key diff --git a/kischdle/llmux/tests/test_auth.py b/kischdle/llmux/tests/test_auth.py new file mode 100644 index 0000000..5a7cbd8 --- /dev/null +++ b/kischdle/llmux/tests/test_auth.py @@ -0,0 +1,55 @@ +import pytest +from fastapi import FastAPI, Depends +from fastapi.testclient import TestClient + +from llmux.auth import create_api_key_dependency +from llmux.config import ApiKey + + +@pytest.fixture +def app_with_auth(): + keys = [ + ApiKey(key="sk-test-valid-key", name="Test"), + ApiKey(key="sk-test-another-key", name="Another"), + ] + require_api_key = create_api_key_dependency(keys) + + app = FastAPI() + + @app.get("/protected") + def protected(api_key: str = Depends(require_api_key)): + return {"key_name": api_key} + + return app + + +@pytest.fixture +def client(app_with_auth): + return TestClient(app_with_auth) + + +def test_valid_key_returns_200(client): + resp = client.get("/protected", headers={"Authorization": "Bearer sk-test-valid-key"}) + assert resp.status_code == 200 + assert resp.json()["key_name"] == "Test" + + +def test_another_valid_key(client): + resp = client.get("/protected", headers={"Authorization": "Bearer sk-test-another-key"}) + assert resp.status_code == 200 + assert resp.json()["key_name"] == "Another" + + +def test_missing_auth_header_returns_401(client): + resp = client.get("/protected") + assert resp.status_code == 401 + + +def test_invalid_key_returns_401(client): + resp = client.get("/protected", headers={"Authorization": "Bearer sk-wrong"}) + assert resp.status_code == 401 + + +def test_malformed_header_returns_401(client): + resp = client.get("/protected", headers={"Authorization": "sk-test-valid-key"}) + assert resp.status_code == 401