import re
import os

# TF_USE_LEGACY_KERAS must be set before transformers/tensorflow are imported,
# otherwise it has no effect on which Keras implementation gets loaded.
os.environ['TF_USE_LEGACY_KERAS'] = "1"

from transformers import (BartTokenizerFast,
                          TFAutoModelForSeq2SeqLM)
import tensorflow as tf
from scraper import scrape_text
from fastapi import FastAPI, Response
from typing import List
from pydantic import BaseModel
import uvicorn
import json
import logging
import multiprocessing


SUMM_CHECKPOINT = "facebook/bart-base"
SUMM_INPUT_N_TOKENS = 400
SUMM_TARGET_N_TOKENS = 300


def load_summarizer_models():
    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
    logging.warning('Loaded summarizer models')
    return summ_tokenizer, summ_model


async def summ_preprocess(txt):
    txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt) # By . Ellie Zolfagharifard . 
    txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt) # 10:30 EST
    txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt) # 10 November 1990
    txt = txt.replace('PUBLISHED:', ' ')
    txt = txt.replace('UPDATED', ' ')
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = txt.replace(' : ', ' ')
    txt = txt.replace('(CNN)', ' ')
    txt = txt.replace('--', ' ')
    txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt) # remove puncts at beginning of sent
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt) # remove puncts with spaces before and after
    txt = re.sub(r'\n+',' ', txt)
    txt = " ".join(txt.split())
    return txt


async def summ_inference_tokenize(input_: list, n_tokens: int):
    tokenized_data = summ_tokenizer(text=input_, max_length=n_tokens, truncation=True, padding="max_length", return_tensors="tf")
    return summ_tokenizer, tokenized_data


async def summ_inference(txts: list):
    txts = [await summ_preprocess(txt) for txt in txts]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS)
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    result = ["" if t == "" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)]
    return result

# def scrape_multi_process(urls):
#     logging.warning('Entering scrape_multi_process() to extract new news articles')
#     '''
#     Scrape each URL in a separate worker process and aggregate
#     the returned texts into a single list.
#     '''
#     pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())

#     results = []
#     for url in urls:
#         f = pool.apply_async(scrape_text, [url])  # each worker scrapes one URL asynchronously
#         results.append(f)

#     scraped_texts = []
#     for f in results:
#         scraped_texts.append(f.get(timeout=120))
#     pool.close()
#     pool.join()
#     logging.warning('Exiting scrape_multi_process()')
#     return scraped_texts

async def scrape_urls(urls):
    scraped_texts = []
    scrape_errors = []
    for url in urls:
        text, err = await scrape_text(url)
        scraped_texts.append(text)
        scrape_errors.append(err)
    return scraped_texts, scrape_errors

##### API #####
app = FastAPI()
summ_tokenizer, summ_model = load_summarizer_models()

class URLList(BaseModel):
    urls: List[str]
    key: str


class NewsSummarizerAPIAuthenticationError(Exception):
    pass 


def authenticate_key(api_key: str):
    if api_key != os.getenv('API_KEY'):
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")

@app.post("/generate_summary/")
async def read_items(q: URLList):
    try:
        urls = ""
        scraped_texts = ""
        scrape_errors = ""
        summaries = ""
        urls = q.urls
        api_key = q.key
        authenticate_key(api_key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except Exception as e:
        status_code = 500
        if isinstance(e, NewsSummarizerAPIAuthenticationError):
            status_code = 401
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'error: {e}'}

    json_str = json.dumps(response_json, indent=5) # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)


if __name__ == '__main__':
    uvicorn.run(app=app, host='0.0.0.0', port=7860)
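

# Example request, as a rough sketch: assumes the service is running locally on
# port 7860 (per uvicorn.run above), that API_KEY is set in the environment,
# and that the article URL below is a placeholder.
#
#   curl -X POST http://localhost:7860/generate_summary/ \
#        -H "Content-Type: application/json" \
#        -d '{"urls": ["https://example.com/news-article"], "key": "<API_KEY>"}'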