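"""FastAPI service that scrapes news articles from a list of URLs and summarizes them
with a fine-tuned BART model. Requests are authenticated against the API_KEY
environment variable."""
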
import os
# Must be set before TensorFlow/Keras is imported so that transformers uses legacy Keras (tf.keras).
os.environ['TF_USE_LEGACY_KERAS'] = "1"

import re
import json
import logging
import multiprocessing  # used only by the commented-out scrape_multi_process() below
from typing import List

import tensorflow as tf
from transformers import BartTokenizerFast, TFAutoModelForSeq2SeqLM
from fastapi import FastAPI, Response
from pydantic import BaseModel
import uvicorn

from scraper import scrape_text

# Summarizer configuration: base checkpoint and token budgets for the input and the generated summary.
SUMM_CHECKPOINT = "facebook/bart-base"
SUMM_INPUT_N_TOKENS = 400
SUMM_TARGET_N_TOKENS = 300
def load_summarizer_models():
    """Load the BART tokenizer and model, then restore the fine-tuned summarizer weights."""
    summ_tokenizer = BartTokenizerFast.from_pretrained(SUMM_CHECKPOINT)
    summ_model = TFAutoModelForSeq2SeqLM.from_pretrained(SUMM_CHECKPOINT)
    summ_model.load_weights(os.path.join("models", "bart_en_summarizer.h5"), by_name=True)
    logging.warning('Loaded summarizer models')
    return summ_tokenizer, summ_model
def summ_preprocess(txt):
    """Strip CNN/DailyMail-style boilerplate from an article before summarization."""
    txt = re.sub(r'^By \. [\w\s]+ \. ', ' ', txt)          # e.g. "By . Ellie Zolfagharifard ."
    txt = re.sub(r'\d{1,2}\:\d\d [a-zA-Z]{3}', ' ', txt)   # e.g. "10:30 EST"
    txt = re.sub(r'\d{1,2} [a-zA-Z]+ \d{4}', ' ', txt)     # e.g. "10 November 1990"
    txt = txt.replace('PUBLISHED:', ' ')
    txt = txt.replace('UPDATED', ' ')
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)            # punctuation surrounded by spaces
    txt = txt.replace(' : ', ' ')
    txt = txt.replace('(CNN)', ' ')
    txt = txt.replace('--', ' ')
    txt = re.sub(r'^\s*[\,\.\:\'\;\|]', ' ', txt)          # punctuation at the start of the text
    txt = re.sub(r' [\,\.\:\'\;\|] ', ' ', txt)            # repeat once more after the removals above
    txt = re.sub(r'\n+', ' ', txt)
    txt = " ".join(txt.split())                            # collapse whitespace
    return txt
async def summ_inference_tokenize(input_: list, n_tokens: int):
    """Tokenize a batch of texts, truncated/padded to a fixed length of n_tokens."""
    tokenized_data = summ_tokenizer(text=input_, max_length=n_tokens, truncation=True, padding="max_length", return_tensors="tf")
    return summ_tokenizer, tokenized_data
async def summ_inference(txts: List[str]):
    """Preprocess, tokenize and summarize a batch of scraped article texts."""
    txts = [summ_preprocess(t) for t in txts]
    inference_tokenizer, tokenized_data = await summ_inference_tokenize(input_=txts, n_tokens=SUMM_INPUT_N_TOKENS)
    pred = summ_model.generate(**tokenized_data, max_new_tokens=SUMM_TARGET_N_TOKENS)
    # Keep the output aligned with the input: empty inputs produce empty summaries.
    result = ["" if t == "" else inference_tokenizer.decode(p, skip_special_tokens=True).strip() for t, p in zip(txts, pred)]
    return result
# def scrape_multi_process(urls):
#     '''
#     Scrape each URL in a separate worker process and collect the extracted texts.
#     '''
#     logging.warning('Entering scrape_multi_process() to extract new news articles')
#     pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
#     results = []
#     for url in urls:
#         f = pool.apply_async(scrape_text, [url])  # each worker starts scraping its URL asynchronously
#         results.append(f)
#     scraped_texts = []
#     for f in results:
#         scraped_texts.append(f.get(timeout=120))
#     pool.close()
#     pool.join()
#     logging.warning('Exiting scrape_multi_process()')
#     return scraped_texts
async def scrape_urls(urls):
    """Scrape each URL sequentially, collecting the extracted text and any error per URL."""
    scraped_texts = []
    scrape_errors = []
    for url in urls:
        text, err = await scrape_text(url)
        scraped_texts.append(text)
        scrape_errors.append(err)
    return scraped_texts, scrape_errors
##### API #####
app = FastAPI()
summ_tokenizer, summ_model = load_summarizer_models()

class URLList(BaseModel):
    urls: List[str]
    key: str

class NewsSummarizerAPIAuthenticationError(Exception):
    pass

def authenticate_key(api_key: str):
    """Raise an authentication error if the supplied key does not match the API_KEY environment variable."""
    if api_key != os.getenv('API_KEY'):
        raise NewsSummarizerAPIAuthenticationError("Authentication error: Invalid API key.")
@app.post("/generate_summary/")
async def read_items(q: URLList):
    try:
        urls = ""
        scraped_texts = ""
        scrape_errors = ""
        summaries = ""
        request_json = json.loads(q.json())
        urls = request_json['urls']
        api_key = request_json['key']
        _ = authenticate_key(api_key)
        scraped_texts, scrape_errors = await scrape_urls(urls)
        summaries = await summ_inference(scraped_texts)
        status_code = 200
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': summaries, 'summarizer_error': ''}
    except Exception as e:
        status_code = 500
        if e.__class__.__name__ == "NewsSummarizerAPIAuthenticationError":
            status_code = 401
        response_json = {'urls': urls, 'scraped_texts': scraped_texts, 'scrape_errors': scrape_errors, 'summaries': "", 'summarizer_error': f'error: {e}'}
    json_str = json.dumps(response_json, indent=5)  # convert dict to JSON str
    return Response(content=json_str, media_type='application/json', status_code=status_code)
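
# Example usage (illustrative only; the URL and key below are placeholders, and this
# assumes the `requests` package is installed and the server is running locally on port 7860):
#
#   import requests
#   payload = {"urls": ["https://www.example.com/some-news-article"], "key": "<your API_KEY>"}
#   resp = requests.post("http://localhost:7860/generate_summary/", json=payload)
#   print(resp.json()["summaries"])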
if __name__ == '__main__':
    uvicorn.run(app=app, host='0.0.0.0', port=7860)