File size: 853 Bytes
52d185c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from typing import Dict, List, Any
from PIL import Image
import os
import json
import numpy as np
import keras
class PreTrainedPipeline:
    """Inference wrapper around a saved Keras bean-disease image classifier.

    The instance loads the model and its label mapping once. Calling it on a
    single image returns one ``{"label", "score"}`` dict per model output class.
    """

    def __init__(self, path=""):
        """Load the Keras model and the ``id2label`` mapping from *path*.

        Args:
            path: Directory containing
                ``beans_disease_classification_transfer_learning.keras`` and
                ``config.json``. The default ``""`` resolves both files relative
                to the current working directory.

        Raises:
            KeyError: if ``config.json`` has no ``"id2label"`` entry.
        """
        self.model = keras.saving.load_model(
            os.path.join(path, "beans_disease_classification_transfer_learning.keras")
        )
        with open(os.path.join(path, "config.json"), encoding="utf-8") as config_file:
            config = json.load(config_file)
        # Maps stringified class indices ("0", "1", ...) to label names.
        # JSON object keys are always strings, hence the str(idx) lookups below.
        self.id2label = config["id2label"]

    def __call__(self, inputs: "Image.Image") -> List[Dict[str, Any]]:
        """Classify a single image.

        Args:
            inputs: The image to classify. Anything ``np.array`` can convert to
                an ``(H, W, C)`` pixel array is accepted.

        Returns:
            One ``{"label": str, "score": float}`` dict per model output, in
            class-index order. Scores are passed through exactly as the model
            returns them.

        Raises:
            KeyError: if the model emits more classes than ``id2label`` names.
        """
        # NOTE(review): no resizing or normalization happens here; presumably
        # the saved model accepts raw pixels at this size or preprocesses
        # internally -- confirm against the training pipeline.
        # Keras predicts on batches, so wrap the single (H, W, C) image in a
        # leading batch axis of size 1.
        batch = np.expand_dims(np.array(inputs), axis=0)
        # Output is (1, num_classes); take the only row of the batch.
        scores = np.asarray(self.model.predict(batch))[0].tolist()
        return [
            {"label": str(self.id2label[str(idx)]), "score": float(score)}
            for idx, score in enumerate(scores)
        ]