svystun-taras's picture
created the updated web ui
0fdb130
import copy
import os
import tempfile
import types
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import torch
from huggingface_hub.utils import SoftTemporaryDirectory
from setfit.utils import set_docstring
from .. import logging
from ..modeling import SetFitModel
from .aspect_extractor import AspectExtractor
if TYPE_CHECKING:
from spacy.tokens import Doc
logger = logging.get_logger(__name__)
@dataclass
class SpanSetFitModel(SetFitModel):
    """SetFit model that classifies candidate spans inside spaCy documents.

    Every candidate span is converted into a text input of the form
    ``"<span text>:<document text>"`` and classified with the inherited ``predict``.
    """

    # Name of the spaCy model; it is saved with the model and written to the model card.
    spacy_model: str = "en_core_web_lg"
    # Number of extra words taken before and after a span when building its input prefix.
    span_context: int = 0
    # Names of the attributes to persist (presumably read by SetFitModel on save -- TODO confirm).
    attributes_to_save: Set[str] = field(
        init=False,
        repr=False,
        default_factory=lambda: {"normalize_embeddings", "labels", "span_context", "spacy_model"},
    )

    def prepend_aspects(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[str]:
        """Yield one classifier input per candidate span.

        Args:
            docs: Documents, aligned one-to-one with ``aspects_list``.
            aspects_list: For every document, the slices locating its candidate spans.

        Yields:
            The span text (widened by ``span_context`` on both sides, with the start clamped
            to 0), followed by ``":"`` and the full document text.

        Note:
            This is a generator despite the ``List[str]`` annotation; wrap the call in
            ``list(...)`` when a concrete list is required.
        """
        for doc, aspects in zip(docs, aspects_list):
            for aspect_slice in aspects:
                # Clamp the start at 0 so a negative index never wraps around to the document end.
                aspect = doc[max(aspect_slice.start - self.span_context, 0) : aspect_slice.stop + self.span_context]
                # TODO: Investigate performance difference of different formats
                yield aspect.text + ":" + doc.text

    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[bool]:
        """Classify every candidate span and regroup the predictions per document.

        Args:
            docs: Documents, aligned one-to-one with ``aspects_list``.
            aspects_list: For every document, the slices locating its candidate spans.

        Returns:
            A list with one entry per document; each entry is the list of predictions for
            that document's spans, in the same order as ``aspects_list`` (i.e. the result is
            nested, despite the ``List[bool]`` annotation).
        """
        inputs_list = list(self.prepend_aspects(docs, aspects_list))
        preds = self.predict(inputs_list, as_numpy=True)
        # Predictions come back flat; consume them in order to rebuild the per-document nesting.
        iter_preds = iter(preds)
        return [[next(iter_preds) for _ in aspects] for aspects in aspects_list]

    def create_model_card(self, path: str, model_name: Optional[str] = None) -> None:
        """Create and save a model card (``README.md``) for this span model.

        Args:
            path (str): Directory to write ``README.md`` into. It is created when missing.
            model_name (str, *optional*): The name of the model. For an aspect model it is
                used as the aspect model name, otherwise as the polarity model name; the
                counterpart name is derived by swapping a trailing ``-aspect``/``-polarity``
                suffix. When omitted, the placeholder names ``setfit-absa-aspect`` and
                ``setfit-absa-polarity`` are used.
        """
        os.makedirs(path, exist_ok=True)

        if model_name is not None:
            # If the model_path is a folder that exists locally, i.e. when create_model_card is called
            # via push_to_hub, and the path is in a temporary folder, then we only take the last two
            # directories
            # NOTE: this must only run for a real name; ``Path(None)`` raises a TypeError.
            model_path = Path(model_name)
            if model_path.exists() and Path(tempfile.gettempdir()) in model_path.resolve().parents:
                model_name = "/".join(model_path.parts[-2:])

        is_aspect = isinstance(self, AspectModel)
        aspect_model = "setfit-absa-aspect"
        polarity_model = "setfit-absa-polarity"
        if model_name is not None:
            if is_aspect:
                aspect_model = model_name
                if model_name.endswith("-aspect"):
                    polarity_model = model_name[: -len("-aspect")] + "-polarity"
            else:
                polarity_model = model_name
                if model_name.endswith("-polarity"):
                    aspect_model = model_name[: -len("-polarity")] + "-aspect"

        # Only once:
        if self.model_card_data.absa is None and self.model_card_data.model_name:
            # Imported lazily so spaCy is only needed when the ABSA metadata is first filled in.
            from spacy import __version__ as spacy_version

            self.model_card_data.model_name = self.model_card_data.model_name.replace(
                "SetFit", "SetFit Aspect Model" if is_aspect else "SetFit Polarity Model", 1
            )
            self.model_card_data.tags.insert(1, "absa")
            self.model_card_data.version["spacy"] = spacy_version
        self.model_card_data.absa = {
            "is_absa": True,
            "is_aspect": is_aspect,
            "spacy_model": self.spacy_model,
            "aspect_model": aspect_model,
            "polarity_model": polarity_model,
        }
        if self.model_card_data.task_name is None:
            self.model_card_data.task_name = "Aspect Based Sentiment Analysis (ABSA)"
        self.model_card_data.inference = False
        with open(os.path.join(path, "README.md"), "w", encoding="utf-8") as f:
            f.write(self.generate_model_card())
# Rewrite the tail of the inherited ``from_pretrained`` docstring: everything from the
# "multi_target_strategy" entry onwards is replaced by the parameters relevant to span models.
docstring = SpanSetFitModel.from_pretrained.__doc__
cut_index = docstring.find("multi_target_strategy")
# ``str.find`` returns -1 when the marker is absent; in that case nothing is patched.
if cut_index >= 0:
    docstring = docstring[:cut_index]
    docstring += """model_card_data (`SetFitModelCardData`, *optional*):
A `SetFitModelCardData` instance storing data such as model language, license, dataset name,
etc. to be used in the automatically generated model cards.
use_differentiable_head (`bool`, *optional*):
Whether to load SetFit using a differentiable (i.e., Torch) head instead of Logistic Regression.
normalize_embeddings (`bool`, *optional*):
Whether to apply normalization on the embeddings produced by the Sentence Transformer body.
span_context (`int`, defaults to `0`):
The number of words before and after the span candidate that should be prepended to the full sentence.
By default, 0 for Aspect models and 3 for Polarity models.
device (`Union[torch.device, str]`, *optional*):
The device on which to load the SetFit model, e.g. `"cuda:0"`, `"mps"` or `torch.device("cuda")`."""
    SpanSetFitModel.from_pretrained = set_docstring(SpanSetFitModel.from_pretrained, docstring, cls=SpanSetFitModel)
class AspectModel(SpanSetFitModel):
    """Span model that keeps only the candidate spans predicted as ``"aspect"``."""

    def __call__(self, docs: List["Doc"], aspects_list: List[List[slice]]) -> List[bool]:
        """Filter the candidate spans of every document down to the predicted aspects.

        Args:
            docs: Documents, aligned one-to-one with ``aspects_list``.
            aspects_list: For every document, the slices locating its candidate spans.

        Returns:
            One list per document holding the slices whose prediction equals ``"aspect"``,
            in their original order (a nested list of slices, despite the ``List[bool]``
            annotation).
        """
        predictions_per_doc = super().__call__(docs, aspects_list)
        kept_per_doc = []
        for candidates, labels in zip(aspects_list, predictions_per_doc):
            kept = []
            for candidate, label in zip(candidates, labels):
                if label == "aspect":
                    kept.append(candidate)
            kept_per_doc.append(kept)
        return kept_per_doc
# A consequence of the set_docstring magic is that subclasses need to update the cls in the from_pretrained
# classmethod; otherwise the wrong class will be instantiated.
# The underlying function is therefore re-bound here with AspectModel as its ``cls``.
AspectModel.from_pretrained = types.MethodType(AspectModel.from_pretrained.__func__, AspectModel)
@dataclass
class PolarityModel(SpanSetFitModel):
    """Span model whose ``span_context`` defaults to 3 instead of the inherited 0."""

    span_context: int = 3


# Re-bind the underlying from_pretrained function so that its ``cls`` is PolarityModel.
PolarityModel.from_pretrained = types.MethodType(PolarityModel.from_pretrained.__func__, PolarityModel)
@dataclass
class AbsaModel:
    """Aspect-Based Sentiment Analysis pipeline built from three components.

    The aspect extractor proposes candidate spans, the aspect model filters them, and the
    polarity model assigns a polarity to every remaining span.
    """

    aspect_extractor: AspectExtractor
    aspect_model: AspectModel
    polarity_model: PolarityModel

    def predict(
        self, inputs: Union[str, List[str]]
    ) -> Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]]:
        """Predict aspect spans and their polarities for one or more texts.

        Args:
            inputs: A single text or a list of texts.

        Returns:
            For a list input, one list per text containing a ``{"span": ..., "polarity": ...}``
            dictionary for every detected aspect. For a single string input, just the list of
            dictionaries for that text (empty when no aspect was found).
        """
        is_str = isinstance(inputs, str)
        inputs_list = [inputs] if is_str else inputs
        docs, aspects_list = self.aspect_extractor(inputs_list)
        # Only run the aspect model when the extractor proposed at least one candidate.
        if any(aspects_list):
            aspects_list = self.aspect_model(docs, aspects_list)
        if not any(aspects_list):
            # Nothing left to classify. Keep the return shape consistent with the regular
            # path: a flat (empty) list for a single string, one empty list per text otherwise.
            return [] if is_str else aspects_list

        polarity_list = self.polarity_model(docs, aspects_list)
        outputs = [
            [
                {"span": doc[aspect_slice].text, "polarity": polarity}
                for aspect_slice, polarity in zip(aspects, polarities)
            ]
            for doc, aspects, polarities in zip(docs, aspects_list, polarity_list)
        ]
        return outputs[0] if is_str else outputs

    @property
    def device(self) -> torch.device:
        """Device of the aspect model."""
        return self.aspect_model.device

    def to(self, device: Union[str, torch.device]) -> "AbsaModel":
        """Move both the aspect and the polarity model to ``device``.

        Args:
            device: Target device, forwarded to the ``to`` method of both models.

        Returns:
            This instance, so that ``model = model.to(device)`` works as annotated.
        """
        self.aspect_model.to(device)
        self.polarity_model.to(device)
        return self

    def __call__(self, inputs: Union[str, List[str]]) -> List[Dict[str, Any]]:
        """Alias for :meth:`predict`."""
        return self.predict(inputs)

    def save_pretrained(
        self,
        save_directory: Union[str, Path],
        polarity_save_directory: Optional[Union[str, Path]] = None,
        push_to_hub: bool = False,
        **kwargs,
    ) -> None:
        """Save the aspect and the polarity model.

        Args:
            save_directory: Directory for the aspect model. When ``polarity_save_directory``
                is omitted it is treated as a base path instead: the aspect model is saved to
                ``<save_directory>-aspect`` and the polarity model to
                ``<save_directory>-polarity``.
            polarity_save_directory: Directory for the polarity model.
            push_to_hub: Forwarded to ``save_pretrained`` of both models.
            **kwargs: Forwarded to ``save_pretrained`` of both models.
        """
        if polarity_save_directory is None:
            base_save_directory = Path(save_directory)
            save_directory = base_save_directory.parent / (base_save_directory.name + "-aspect")
            polarity_save_directory = base_save_directory.parent / (base_save_directory.name + "-polarity")
        self.aspect_model.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
        self.polarity_model.save_pretrained(save_directory=polarity_save_directory, push_to_hub=push_to_hub, **kwargs)

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        polarity_model_id: Optional[str] = None,
        spacy_model: Optional[str] = None,
        span_contexts: Tuple[Optional[int], Optional[int]] = (None, None),
        force_download: Optional[bool] = None,
        resume_download: Optional[bool] = None,
        proxies: Optional[Dict] = None,
        token: Optional[Union[str, bool]] = None,
        cache_dir: Optional[str] = None,
        local_files_only: Optional[bool] = None,
        use_differentiable_head: Optional[bool] = None,
        normalize_embeddings: Optional[bool] = None,
        **model_kwargs,
    ) -> "AbsaModel":
        """Load the aspect and polarity models and build the pipeline around them.

        Args:
            model_id: Identifier of the aspect model. A single ``@`` splits it into
                ``<model_id>@<revision>``.
            polarity_model_id: Identifier of the polarity model, with the same optional
                ``@<revision>`` suffix. When omitted, the aspect model's identifier and
                revision are reused.
            spacy_model: When given, forwarded to both models as ``spacy_model``.
            span_contexts: ``span_context`` values forwarded to the aspect model and the
                polarity model, respectively.
            force_download, resume_download, proxies, token, cache_dir, local_files_only,
            use_differentiable_head, normalize_embeddings: Forwarded unchanged to
                ``from_pretrained`` of both models.
            **model_kwargs: Extra keyword arguments forwarded to both models. A
                ``model_card_data`` entry is deep-copied for the polarity model so the two
                models do not share one instance.

        Returns:
            An ``AbsaModel`` whose aspect extractor uses the aspect model's spaCy model.
        """
        revision = None
        if len(model_id.split("@")) == 2:
            model_id, revision = model_id.split("@")

        if spacy_model:
            model_kwargs["spacy_model"] = spacy_model

        # Loader options that are identical for the aspect and the polarity model.
        shared_kwargs = {
            "force_download": force_download,
            "resume_download": resume_download,
            "proxies": proxies,
            "token": token,
            "cache_dir": cache_dir,
            "local_files_only": local_files_only,
            "use_differentiable_head": use_differentiable_head,
            "normalize_embeddings": normalize_embeddings,
        }

        aspect_model = AspectModel.from_pretrained(
            model_id,
            span_context=span_contexts[0],
            revision=revision,
            labels=["no aspect", "aspect"],
            **shared_kwargs,
            **model_kwargs,
        )
        if polarity_model_id:
            model_id = polarity_model_id
            revision = None
            if len(model_id.split("@")) == 2:
                model_id, revision = model_id.split("@")
        # If model_card_data was provided, "separate" the instance between the Aspect
        # and Polarity models.
        model_card_data = model_kwargs.pop("model_card_data", None)
        if model_card_data:
            model_kwargs["model_card_data"] = copy.deepcopy(model_card_data)
        polarity_model = PolarityModel.from_pretrained(
            model_id,
            span_context=span_contexts[1],
            revision=revision,
            **shared_kwargs,
            **model_kwargs,
        )
        if aspect_model.spacy_model != polarity_model.spacy_model:
            logger.warning(
                "The Aspect and Polarity models are configured to use different spaCy models:\n"
                f"* {repr(aspect_model.spacy_model)} for the aspect model, and\n"
                f"* {repr(polarity_model.spacy_model)} for the polarity model.\n"
                f"This model will use {repr(aspect_model.spacy_model)}."
            )

        aspect_extractor = AspectExtractor(spacy_model=aspect_model.spacy_model)

        return cls(aspect_extractor, aspect_model, polarity_model)

    def push_to_hub(self, repo_id: str, polarity_repo_id: Optional[str] = None, **kwargs) -> None:
        """Save both models into a temporary directory with ``push_to_hub=True``.

        Args:
            repo_id: Full repository ID (including organisation) for the aspect model. When
                ``polarity_repo_id`` is omitted it acts as a base name and
                :meth:`save_pretrained` appends ``-aspect`` / ``-polarity``.
            polarity_repo_id: Full repository ID for the polarity model.
            **kwargs: Forwarded to :meth:`save_pretrained`. ``commit_message`` defaults to
                ``"Add SetFit ABSA model"``.

        Raises:
            ValueError: If ``repo_id`` or a given ``polarity_repo_id`` contains no ``"/"``.
        """
        if "/" not in repo_id:
            raise ValueError(
                '`repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )

        if polarity_repo_id is not None and "/" not in polarity_repo_id:
            raise ValueError(
                '`polarity_repo_id` must be a full repository ID, including organisation, e.g. "tomaarsen/setfit-absa-restaurant".'
            )

        commit_message = kwargs.pop("commit_message", "Add SetFit ABSA model")

        # Push the files to the repo in a single commit
        with SoftTemporaryDirectory() as tmp_dir:
            save_directory = Path(tmp_dir) / repo_id
            polarity_save_directory = None if polarity_repo_id is None else Path(tmp_dir) / polarity_repo_id
            self.save_pretrained(
                save_directory=save_directory,
                polarity_save_directory=polarity_save_directory,
                push_to_hub=True,
                commit_message=commit_message,
                **kwargs,
            )