import os

import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer


os.makedirs("static", exist_ok=True)
app = FastAPI(docs_url="/docs", redoc_url=None)

app.mount("/static", StaticFiles(directory="static"), name="static")

# Allow cross-origin requests from any origin so the static frontend can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class TranslationRequest(BaseModel):
    user_input: str
    source_lang: str
    target_lang: str
    # Available models: facebook/m2m100_418M | facebook/m2m100_1.2B
    model: str = "facebook/m2m100_418M"


def load_model(model_name: str = "facebook/m2m100_418M", cache_dir: str = "models/"):
    """Load the M2M100 tokenizer and model, caching the weights under cache_dir."""
    model_dir = os.path.join(os.getcwd(), cache_dir)
    tokenizer = M2M100Tokenizer.from_pretrained(model_name, cache_dir=model_dir)
    model = M2M100ForConditionalGeneration.from_pretrained(model_name, cache_dir=model_dir).to(device)
    model.eval()
    return tokenizer, model


# Apparently we have a problem loading the model, so load it once at application
# startup (warming the local cache) to avoid a time-out on the first request.
load_model()
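
# A possible refinement (a sketch, not part of the original code): since
# /translate calls load_model on every request, the tokenizer/model pair could
# be memoized so each model is only built once per process, e.g.:
#
#     from functools import lru_cache
#
#     @lru_cache(maxsize=2)
#     def load_model(model_name: str = "facebook/m2m100_418M", cache_dir: str = "models/"):
#         ...
#
# Repeat requests would then reuse the in-memory model instead of re-reading
# the cached weights from disk.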

@app.get("/", response_class=FileResponse)
async def read_index():
    """
    Serve the index.html file from the static folder.
    """
    return FileResponse("static/index.html")

@app.post("/translate")
async def translate(request: TranslationRequest):
    """
    models: facebook/m2m100_418M | facebook/m2m100_1.2B

    Language support:
    Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)
    """
    try:
        tokenizer, model = load_model(model_name=request.model)
    except Exception as e:
        return {"error": str(e)}
    
    tokenizer.src_lang = request.source_lang
    with torch.no_grad():
        encoded_input = tokenizer(request.user_input, return_tensors="pt").to(device)
        generated_tokens = model.generate(
            **encoded_input,
            forced_bos_token_id=tokenizer.get_lang_id(request.target_lang),
        )
        translated_text = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]

    return {"translation": translated_text}
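

# A minimal usage sketch (assuming the server is running locally on the port
# configured below; the field names follow the TranslationRequest model above):
#
#     curl -X POST http://localhost:7860/translate \
#          -H "Content-Type: application/json" \
#          -d '{"user_input": "Hello world", "source_lang": "en", "target_lang": "pt"}'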


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)