Karlo Pintaric
Upload 25 files
fdc1efd
import io
import sys
from pathlib import Path
import soundfile as sf
from fastapi.testclient import TestClient
sys.path.append(".")
from src.api.main import app # noqa
TEST_FILES_DIR = Path(__file__).parent / "test_files"
TEST_WAV_FILE = TEST_FILES_DIR / "test.wav"
client = TestClient(app)
def test_health_check():
response = client.get("/health-check")
assert response.status_code == 200
assert response.json() == {"status": "API is running"}
def test_predict_valid_cut_file():
audio_data, sample_rate = sf.read(TEST_WAV_FILE)
audio_file = io.BytesIO()
sf.write(audio_file, audio_data, sample_rate, format="wav")
audio_file = ("test.wav", audio_file)
file = {"file": audio_file}
request_data = {"model_name": "Accuracy"}
# Make a request to the /predict endpoint
response = client.post("/predict", params=request_data, files=file)
# Check that the response is successful
assert response.status_code == 200
assert response.json()["prediction"]["test.wav"] is not None
def test_predict_valid_file():
with open(TEST_WAV_FILE, "rb") as file:
data = {"model_name": "Accuracy"}
response = client.post("/predict", params=data, files={"file": file})
assert response.status_code == 200
assert response.json()["prediction"]["test.wav"] is not None
def test_predict_invalid_file_type():
file_data = io.BytesIO(b"dummy txt data")
file = ("test.txt", file_data)
data = {"model_name": "Accuracy"}
response = client.post("/predict", params=data, files={"file": file})
assert response.status_code == 400
assert "Only wav files are supported" in response.json()["detail"]
def test_predict_invalid_model():
file_data = io.BytesIO(b"dummy wav data")
file = ("test.wav", file_data)
data = {"model_name": "InvalidModel"}
response = client.post("/predict", params=data, files={"file": file})
assert response.status_code == 400
assert "Selected model doesn't exist" in response.json()["detail"]