Spaces:
Running
Running
from ast import Attribute | |
from dotenv import load_dotenv | |
load_dotenv(override=True) | |
import re | |
import os | |
import pandas as pd | |
import json | |
from typing import List, Dict, Any | |
import pandas as pd | |
import gradio as gr | |
import datetime | |
from pathlib import Path | |
import json | |
from risk_atlas_nexus.blocks.inference import WMLInferenceEngine | |
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams | |
from risk_atlas_nexus.library import RiskAtlasNexus | |
from functools import lru_cache | |
# Load the taxonomies once at import time; the handlers in this module query
# this single shared RiskAtlasNexus instance for risks, related risks and
# mitigation actions.
ran: RiskAtlasNexus = RiskAtlasNexus()  # type: ignore
def risk_identifier(usecase: str,
                    model_name_or_path: str = "ibm/granite-20b-code-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Identify potential risks for a free-text use case with an LLM.

    Builds a WML inference engine for ``model_name_or_path`` (credentials come
    from the ``WML_API_KEY``, ``WML_API_URL`` and ``WML_PROJECT_ID`` environment
    variables), asks Risk Atlas Nexus for the risks in ``taxonomy`` that apply
    to ``usecase``, writes a JSON record of the run to ``static/download.json``
    and returns Gradio component updates.

    Args:
        usecase: Free-text description of the intended use of the AI system.
        model_name_or_path: WML model identifier used for risk identification.
        taxonomy: Id of the risk taxonomy to search.

    Returns:
        A 4-tuple of: a section-header ``gr.Markdown``; a ``gr.State`` holding
        the identified risk objects; a ``gr.Dataset`` of risk ids labelled by
        risk name; and a ``gr.DownloadButton`` pointing at the JSON record.

    Raises:
        KeyError: if any of the WML_* environment variables is not set.
    """
    inference_engine = WMLInferenceEngine(
        model_name_or_path=model_name_or_path,
        credentials={
            "api_key": os.environ["WML_API_KEY"],
            "api_url": os.environ["WML_API_URL"],
            "project_id": os.environ["WML_PROJECT_ID"],
        },
        parameters=WMLInferenceEngineParams(
            max_new_tokens=150, decoding_method="greedy", repetition_penalty=1
        ),  # type: ignore
    )

    # The library takes a list of use cases; we send exactly one and keep the
    # first (only) result.
    risks = ran.identify_risks_from_usecases(  # type: ignore
        usecases=[usecase],
        inference_engine=inference_engine,
        taxonomy=taxonomy,
    )[0]

    # Label each risk by its human-readable name, falling back to its id when
    # the name is empty. (Testing ``r`` itself instead of ``r.name`` meant the
    # fallback never applied and would raise AttributeError on a None risk.)
    sample_labels = [r.name or r.id for r in risks]

    out_sec = gr.Markdown("""<h2> Potential Risks </h2> """)

    # Write a JSON record of this run for the download button.
    download_path = "static/download.json"
    data = {
        'time': str(datetime.datetime.now(datetime.timezone.utc)),
        'intent': usecase,
        'model': model_name_or_path,
        'taxonomy': taxonomy,
        'risks': [json.loads(r.json()) for r in risks],
    }
    file_path = Path(download_path)
    # Create the output directory if needed so a fresh checkout doesn't fail.
    file_path.parent.mkdir(parents=True, exist_ok=True)
    file_path.write_text(json.dumps(data, indent=4), encoding='utf-8')
    # NOTE(review): every call writes the same path, so concurrent users can
    # overwrite each other's download; consider a per-request file name.

    return (
        out_sec,
        gr.State(risks),
        gr.Dataset(samples=[r.id for r in risks],
                   sample_labels=sample_labels,
                   samples_per_page=50, visible=True,
                   label="Estimated by an LLM."),
        gr.DownloadButton("Download JSON", visible=True, value=download_path),
    )
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr.DataFrame, gr.Markdown]:
    """
    For a specific risk (riskid), returns Gradio updates for
    (a) a risk description (empty if the risk or its description is missing),
    (b) related risks from other taxonomies - as a dataset,
    (c) mitigation actions - as a dataframe of name/description, and
    (d) a status message (blank, or "No mitigations found.").

    Args:
        riskid: Id of the selected risk.
        taxonomy: Taxonomy the risk was identified in. For "ibm-risk-atlas",
            mitigations are gathered from the risks related to ``riskid``;
            for any other taxonomy, from ``riskid`` itself.
    """
    # get_risk may return None for an unknown id -> AttributeError; show nothing.
    try:
        risk_desc = ran.get_risk(id=riskid).description  # type: ignore
    except AttributeError:
        risk_sec = ""
    else:
        risk_sec = f"<h3>Description: </h3> {risk_desc}"

    # Fetch related risks once and reuse them for both ids and labels.
    related_risks = ran.get_related_risks(id=riskid) or []
    related_risk_ids = [r.id for r in related_risks]

    if taxonomy == "ibm-risk-atlas":
        # Gather actions from every related risk. A related risk with no
        # actions contributes nothing; it must not discard actions collected
        # from earlier related risks.
        action_ids = []
        for related_id in related_risk_ids:
            action_ids += ran.get_related_actions(id=related_id) or []
    else:
        # Use only actions related to the primary risk
        action_ids = ran.get_related_actions(id=riskid) or []
    # Several related risks can share an action; list each action once,
    # keeping first-seen order.
    action_ids = list(dict.fromkeys(action_ids))

    # Related-risk dataset
    if related_risk_ids:
        label = f"Risks from other taxonomies related to {riskid}"
        samples = related_risk_ids
        sample_labels = [r.name or r.id for r in related_risks]  # type: ignore
    else:
        label = "No related risks found."
        samples = None
        sample_labels = None

    # Look each action up once; skip ids the library cannot resolve instead of
    # crashing the handler with AttributeError on None.
    actions = [a for a in (ran.get_action_by_id(i) for i in action_ids) if a is not None]
    if actions:
        alabel = f"Mitigation actions related to risk {riskid}."
        mitdf = pd.DataFrame({
            "Mitigation": [a.name for a in actions],
            "Description": [a.description for a in actions],
        })
    else:
        alabel = "No mitigations found."
        mitdf = pd.DataFrame()

    status = gr.Markdown(" ") if len(mitdf) > 0 else gr.Markdown("No mitigations found.")
    return (gr.Markdown(risk_sec),
            gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
            gr.DataFrame(mitdf, wrap=True, show_copy_button=True, show_search="search", label=alabel, visible=True),
            status)