nyasukun committed on
Commit 18eee30 · 1 Parent(s): 4e1363d
Files changed (1):
  1. app.py (+56 -35)
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import spaces
-from transformers import pipeline, AutoModelForCausalLM
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 import torch
 import logging
 
@@ -11,44 +11,55 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
+# Stores for models and tokenizers
+tokenizers = {}
+pipelines = {}
+
 # Predefined list of models to compare (can be expanded)
 model_options = {
-    "Foundation-Sec-8B": pipeline("text-generation", model="fdtn-ai/Foundation-Sec-8B", torch_dtype=torch.bfloat16),
+    "Foundation-Sec-8B": "fdtn-ai/Foundation-Sec-8B",
 }
 
-@spaces.GPU
-def generate_text_local(model_pipeline, prompt):
-    """Local text generation"""
+# Initialize models at startup
+for model_name, model_path in model_options.items():
     try:
-        # Get the model name (fall back to 'unknown')
-        model_name = getattr(getattr(model_pipeline, "model", None), "name_or_path", "unknown")
-        logger.info(f"Running local text generation with {model_name}")
-
-        # Move model to GPU (entire pipeline)
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        if hasattr(model_pipeline, "model"):
-            model_pipeline.model = model_pipeline.model.to(device)
-
-        # Record device information
-        device_info = next(model_pipeline.model.parameters()).device if hasattr(model_pipeline, "model") else "unknown"
-        logger.info(f"Model {model_name} is running on device: {device_info}")
-
-        outputs = model_pipeline(
-            prompt,
-            max_new_tokens=3,
-            do_sample=True,
-            temperature=0.1,
-            top_p=0.9,
-            clean_up_tokenization_spaces=True,
+        logger.info(f"Initializing text generation model: {model_path}")
+        tokenizers[model_path] = AutoTokenizer.from_pretrained(model_path)
+        pipelines[model_path] = pipeline(
+            "text-generation",
+            model=model_path,
+            tokenizer=tokenizers[model_path],
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            trust_remote_code=True
         )
+        logger.info(f"Model initialized successfully: {model_path}")
+    except Exception as e:
+        logger.error(f"Error initializing model {model_path}: {str(e)}")
 
-        # Move model back to CPU
-        if hasattr(model_pipeline, "model"):
-            model_pipeline.model = model_pipeline.model.to("cpu")
+@spaces.GPU
+def generate_text_local(model_path, prompt, max_new_tokens=512, temperature=0.7, top_p=0.95):
+    """Local text generation"""
+    try:
+        # Use the already initialized model
+        if model_path in pipelines:
+            model_pipeline = pipelines[model_path]
+            logger.info(f"Running text generation with {model_path}")
+
+            outputs = model_pipeline(
+                prompt,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                temperature=temperature,
+                top_p=top_p,
+                clean_up_tokenization_spaces=True,
+            )
 
-        return outputs[0]["generated_text"].replace(prompt, "").strip()
+            return outputs[0]["generated_text"].replace(prompt, "").strip()
+        else:
+            return f"Error: Model {model_path} not initialized"
     except Exception as e:
-        logger.error(f"Error in local text generation with {model_name}: {str(e)}")
+        logger.error(f"Error in text generation with {model_path}: {str(e)}")
         return f"Error: {str(e)}"
 
 # Build Gradio app
@@ -94,13 +105,23 @@ def create_demo():
     ):
         #if len(selected_models) != 2:
        #    return "Error: Please select exactly two models to compare.", ""
-        responses = generate_text_local(
-            #message, [], system_message, max_tokens, temperature, top_p, selected_models
-            model_options[selected_models[0]],
-            message
+
+        if len(selected_models) == 0:
+            return "Error: Please select at least one model"
+
+        model_path = model_options[selected_models[0]]
+        full_prompt = f"{system_message}\n\nUser: {message}\nAssistant:"
+        response = generate_text_local(
+            model_path,
+            full_prompt,
+            max_tokens,
+            temperature,
+            top_p
         )
+
         #return responses.get(selected_models[0], ""), responses.get(selected_models[1], "")
-        return responses
+        return response
+
        # Add a button for generating responses
        submit_button = gr.Button("Generate Responses")
        submit_button.click(
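
For reference, a minimal sketch of the call pattern after this change: the startup loop populates the module-level pipelines/tokenizers dicts once at import time, and generate_text_local only looks up the prewarmed pipeline by model path. The sample system message and user message below are hypothetical; the prompt template is the one introduced in this commit.

# Minimal sketch (hypothetical inputs): exercising the refactored flow
# outside the Gradio UI. Assumes the startup loop above ran successfully.
model_path = model_options["Foundation-Sec-8B"]  # "fdtn-ai/Foundation-Sec-8B"

system_message = "You are a helpful security assistant."   # hypothetical
message = "Summarize CVE-2021-44228 in one sentence."      # hypothetical
full_prompt = f"{system_message}\n\nUser: {message}\nAssistant:"

# The new signature takes a model path string rather than a pipeline object,
# so the @spaces.GPU-decorated function no longer moves weights between devices.
response = generate_text_local(
    model_path,
    full_prompt,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.95,
)
print(response)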