mike23415 committed
Commit 2dddabe · verified · 1 Parent(s): 29c1018

Update app.py

Files changed (1)
  1. app.py +11 -8
app.py CHANGED
@@ -13,6 +13,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 os.environ['HF_HOME'] = '/app/.cache'
 os.environ['XDG_CACHE_HOME'] = '/app/.cache'
 
+
 app = Flask(__name__)
 CORS(app)
 
@@ -21,32 +22,32 @@ model_loaded = False
 load_error = None
 generator = None
 
-# --------------------------------------------------
-# Asynchronous Model Loading
-# --------------------------------------------------
 def load_model():
     global model_loaded, load_error, generator
     try:
-        # Initialize model with low-memory settings
+        # Detect device and dtype automatically
+        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
         model = AutoModelForCausalLM.from_pretrained(
             "gpt2-medium",
             use_safetensors=True,
             device_map="auto",
-            low_cpu_mem_usage=True,
-            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            torch_dtype=dtype,
+            low_cpu_mem_usage=True
         )
 
         tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
 
+        # Initialize pipeline without explicit device assignment
         generator = pipeline(
             'text-generation',
             model=model,
             tokenizer=tokenizer,
-            device=0 if torch.cuda.is_available() else -1
+            torch_dtype=dtype
         )
 
         model_loaded = True
-        print("Model loaded successfully")
+        print(f"Model loaded on {model.device}")
 
     except Exception as e:
         load_error = str(e)
@@ -55,6 +56,8 @@ def load_model():
 # Start model loading in background thread
 Thread(target=load_model).start()
 
+
+
 # --------------------------------------------------
 # IEEE Format Template
 # --------------------------------------------------
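
The substance of the change: the dtype choice is hoisted into a single variable passed to both from_pretrained and pipeline, and the explicit device= argument is dropped from pipeline(). With device_map="auto", accelerate has already placed the weights, and recent transformers versions reject an additional device= on a pipeline built from such a model. A minimal standalone sketch of the resulting pattern, using the same model and arguments as app.py (the prompt at the end is just for illustration):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Half precision only when a GPU is available; float16 generation on CPU
# is slow and poorly supported.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model = AutoModelForCausalLM.from_pretrained(
    "gpt2-medium",
    use_safetensors=True,
    device_map="auto",        # accelerate decides where the weights live
    torch_dtype=dtype,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

# No device= here: the model is already dispatched via device_map, and
# passing a device to pipeline() for an accelerate-loaded model errors out.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    torch_dtype=dtype,
)

print(generator("Hello world,", max_new_tokens=20)[0]["generated_text"])

The new print(f"Model loaded on {model.device}") reports where accelerate actually placed the model, which is more informative than the old fixed success message.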
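Because Thread(target=load_model).start() returns immediately, requests can arrive before model_loaded flips to True. This commit does not show how the rest of app.py handles that window, but a common companion is a status route that polls the two globals; the /health name and response shape below are assumptions for illustration, not part of this commit:

from flask import Flask, jsonify

app = Flask(__name__)
model_loaded = False   # flipped to True by load_model() on success
load_error = None      # set to the exception text on failure

@app.route("/health")
def health():
    if load_error:            # loading failed; surface the error text
        return jsonify({"status": "error", "detail": load_error}), 500
    if not model_loaded:      # still loading; tell clients to retry later
        return jsonify({"status": "loading"}), 503
    return jsonify({"status": "ready"}), 200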