Spaces:
Running
Running
File size: 4,197 Bytes
fb2eca4 a5fb092 fb2eca4 0f6dd16 fb2eca4 ab29dc4 fb2eca4 ba788e4 671a925 bddbfe8 671a925 a96d506 671a925 bddbfe8 03b7efe fb2eca4 63d94aa fb2eca4 f9e997a 0f6dd16 d7ec351 fb2eca4 ca3380e eef203f 671a925 fb2eca4 2005977 63d94aa 2005977 0abcb94 f9e997a 0abcb94 fb2eca4 2fbdede d38295f fb2eca4 10cb7da fb2eca4 d6fcf5e 03b7efe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import io
import re
import time
import os
from typing import List, Literal
from fastapi import FastAPI
from pydantic import BaseModel
from enum import Enum
from transformers import M2M100Tokenizer, M2M100ForConditionalGeneration
import torch
import uvicorn
from fastapi.responses import HTMLResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from enum import Enum
from fastapi.staticfiles import StaticFiles
# Run the model on the first CUDA GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Create the "static" folder first so the mount below always has a directory to serve.
os.makedirs("static", exist_ok=True)

# Interactive docs are served at /docs; ReDoc is disabled.
app = FastAPI(redoc_url=None, docs_url="/docs")
app.mount(
    "/static",
    StaticFiles(directory="static"),
    name="static",
)

# Allow cross-origin calls from any origin, with any HTTP method and any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
class TranslationRequest(BaseModel):
    """JSON body accepted by the /translate endpoint."""

    # Text to be translated.
    user_input: str
    # Language code of `user_input` (one of the codes listed in the /translate docstring, e.g. "en").
    source_lang: str
    # Language code to translate into (same code list as `source_lang`).
    target_lang: str
    # Hugging Face checkpoint name. Checkpoints this service is documented for:
    #   facebook/m2m100_418M
    #   facebook/m2m100_1.2B
    # NOTE(review): this value comes from the client and is forwarded to
    # from_pretrained via load_model, so any hub checkpoint name can be requested;
    # consider restricting it (e.g. with typing.Literal) if that is not intended.
    model: str = 'facebook/m2m100_418M'
# Loaded (tokenizer, model) pairs keyed by (checkpoint name, cache_dir), so each
# checkpoint is read from disk and moved to `device` once per process rather than
# on every call.
_MODEL_CACHE = {}
# Upper bound on resident checkpoints. The checkpoint name can come from the
# request body, so an unbounded cache would let clients pin any number of models in memory.
_MAX_CACHED_MODELS = 2


def load_model(model: str = 'facebook/m2m100_418M', cache_dir: str = "models/"):
    """
    Return a ``(tokenizer, model)`` pair for an M2M100 checkpoint.

    The first call for a given ``(model, cache_dir)`` loads the checkpoint with
    ``from_pretrained`` (using ``<cwd>/<cache_dir>`` as the cache directory),
    moves the network to ``device`` and switches it to eval mode. Later calls with
    the same arguments return the same cached objects. Once ``_MAX_CACHED_MODELS``
    checkpoints are cached, loading a new one evicts the oldest entry.

    Args:
        model: Hugging Face checkpoint name, e.g. 'facebook/m2m100_418M'.
        cache_dir: directory, relative to the current working directory, passed
            to ``from_pretrained`` as ``cache_dir``.

    Returns:
        tuple: (M2M100Tokenizer, M2M100ForConditionalGeneration)

    Raises:
        Whatever ``from_pretrained`` raises for an unknown or unreachable
        checkpoint; nothing is cached in that case.
    """
    key = (model, cache_dir)
    cached = _MODEL_CACHE.get(key)
    if cached is not None:
        return cached

    model_dir = os.path.join(os.getcwd(), cache_dir)
    tokenizer = M2M100Tokenizer.from_pretrained(model, cache_dir=model_dir)
    # A separate name keeps the `model` argument (the checkpoint name) from being
    # shadowed by the loaded network.
    net = M2M100ForConditionalGeneration.from_pretrained(model, cache_dir=model_dir).to(device)
    net.eval()

    if len(_MODEL_CACHE) >= _MAX_CACHED_MODELS:
        # dicts preserve insertion order, so the first key is the oldest entry.
        del _MODEL_CACHE[next(iter(_MODEL_CACHE))]
    _MODEL_CACHE[key] = (tokenizer, net)
    return tokenizer, net


# Loading the model inside a request was causing time-outs, so load the default
# checkpoint once at application start-up; the cache above lets requests reuse it.
load_model()
@app.get("/", response_class=FileResponse)
async def read_index():
    """Serve the front-end page, index.html from the static folder."""
    index_page = "static/index.html"
    return FileResponse(index_page)
@app.post("/translate")
async def translate(request: TranslationRequest):
    """
    Translate ``user_input`` from ``source_lang`` to ``target_lang`` with an M2M100 model.

    Returns {"translation": <text>} on success, or {"error": <message>} when the
    model cannot be loaded or a language code is not supported by the tokenizer.

    models: facebook/m2m100_418M | facebook/m2m100_1.2B
    language support
    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
    """
    try:
        tokenizer, model = load_model(model=request.model)
    except Exception as E:
        return {"error": str(E)}

    # Resolve both language codes before running the model. In M2M100Tokenizer,
    # setting src_lang and calling get_lang_id look the code up in the
    # tokenizer's language table and raise KeyError for an unknown code.
    try:
        tokenizer.src_lang = request.source_lang
        target_lang_id = tokenizer.get_lang_id(request.target_lang)
    except KeyError as E:
        return {"error": f"Unsupported language code: {E}"}

    # NOTE(review): inference runs synchronously inside this async handler, so it
    # blocks the event loop while generate() runs.
    with torch.no_grad():
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        # Forcing the first generated token to the target-language id selects
        # the output language.
        generated_tokens = model.generate(**encoded_input, forced_bos_token_id=target_lang_id)

    # The input is a single string, so the batch contains exactly one output.
    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return {"translation": translated_text}
if __name__ == "__main__":
    # Started directly (python <file>): listen on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")
|