Bobholamovic committed on
Commit
6d6af66
·
1 Parent(s): 8b775e5

Bind thread with model

Browse files
Files changed (1) hide show
  1. app.py +60 -84
app.py CHANGED
@@ -1,11 +1,13 @@
1
- import asyncio
2
  import functools
3
- import uuid
 
4
 
5
  from paddleocr import PaddleOCR, draw_ocr
6
  from PIL import Image
7
  import gradio as gr
8
 
 
9
  LANG_CONFIG = {
10
  "ch": {"num_workers": 4},
11
  "en": {"num_workers": 4},
@@ -17,95 +19,53 @@ LANG_CONFIG = {
17
  CONCURRENCY_LIMIT = 8
18
 
19
 
20
- class PaddleOCRModelWrapper(object):
21
- def __init__(self, model, name=None):
22
- super().__init__()
23
- self._model = model
24
- self._name = name or self._get_random_name()
25
- self._state = "IDLE"
26
-
27
- @property
28
- def name(self):
29
- return self._name
30
-
31
- @property
32
- def state(self):
33
- return self._state
34
-
35
- @state.setter
36
- def state(self, state):
37
- self._state = state
38
-
39
- def infer(self, **kwargs):
40
- img_path = kwargs["img"]
41
- result = self._model.ocr(**kwargs)[0]
42
- image = Image.open(img_path).convert("RGB")
43
- boxes = [line[0] for line in result]
44
- txts = [line[1][0] for line in result]
45
- scores = [line[1][1] for line in result]
46
- im_show = draw_ocr(image, boxes, txts, scores,
47
- font_path="./simfang.ttf")
48
- return im_show
49
-
50
- def _get_random_name(self):
51
- return str(uuid.uuid4())
52
-
53
-
54
  class PaddleOCRModelManager(object):
55
  def __init__(self,
56
- num_models,
57
- model_factory,
58
- *,
59
- polling_interval=0.1):
60
  super().__init__()
61
- self._num_models = num_models
62
  self._model_factory = model_factory
63
- self._polling_interval = polling_interval
64
- self._models = {}
65
- self.new_models()
66
-
67
- def new_models(self):
68
- self._models.clear()
69
- for _ in range(self._num_models):
70
- model = self._new_model()
71
- self._models[model.name] = model
72
-
73
- async def infer(self, **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  while True:
75
- model = self._get_available_model()
76
- if not model:
77
- await asyncio.sleep(self._polling_interval)
78
- continue
79
- model.state = "RUNNING"
80
- # NOTE: I take an optimistic approach here, assuming that the model
81
- # is not broken even if inference fails.
82
  try:
83
- result = await self._new_inference_task(model, **kwargs)
 
 
 
84
  finally:
85
- model.state = "IDLE"
86
- return result
87
-
88
- def _new_model(self):
89
- real_model = self._model_factory()
90
- model = PaddleOCRModelWrapper(real_model)
91
- return model
92
-
93
- def _get_available_model(self):
94
- if not self._models:
95
- raise RuntimeError("No living models")
96
- for model in self._models.values():
97
- if model.state == "IDLE":
98
- return model
99
- return None
100
-
101
- def _new_inference_task(self, model,
102
- **kwargs):
103
- return asyncio.get_running_loop().run_in_executor(
104
- None, functools.partial(model.infer, **kwargs))
105
 
106
 
107
  def create_model(lang):
108
- return PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
109
 
110
 
111
  model_managers = {}
@@ -114,10 +74,26 @@ for lang, config in LANG_CONFIG.items():
114
  model_managers[lang] = model_manager
115
 
116
 
117
- async def inference(img, lang):
 
 
 
 
 
 
 
 
 
118
  ocr = model_managers[lang]
119
- result = await ocr.infer(img=img, cls=True)
120
- return result
 
 
 
 
 
 
 
121
 
122
 
123
  title = 'PaddleOCR'
 
1
+ import atexit
2
  import functools
3
+ from queue import Queue
4
+ from threading import Thread
5
 
6
  from paddleocr import PaddleOCR, draw_ocr
7
  from PIL import Image
8
  import gradio as gr
9
 
10
+
11
  LANG_CONFIG = {
12
  "ch": {"num_workers": 4},
13
  "en": {"num_workers": 4},
 
19
  CONCURRENCY_LIMIT = 8
20
 
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
class PaddleOCRModelManager(object):
    """Serve OCR requests from a fixed pool of worker threads.

    Each worker thread constructs its own model via ``model_factory`` and
    owns it for the thread's lifetime (binding one model per thread avoids
    sharing a model object across threads).  Requests are pushed onto a
    shared queue and answered through a per-request reply queue.
    """

    def __init__(self,
                 num_workers,
                 model_factory):
        """Start ``num_workers`` worker threads.

        Args:
            num_workers: Number of worker threads (one model instance each).
            model_factory: Zero-argument callable returning an object that
                exposes an ``ocr(*args, **kwargs)`` method.
        """
        super().__init__()
        self._model_factory = model_factory
        self._queue = Queue()
        self._workers = []
        for _ in range(num_workers):
            # Workers must be daemon threads: interpreter shutdown joins
            # non-daemon threads *before* atexit handlers run, so an
            # atexit-registered close() could never deliver the sentinels
            # and the process would hang on exit with daemon=False.
            worker = Thread(target=self._worker, daemon=True)
            worker.start()
            self._workers.append(worker)

    def infer(self, *args, **kwargs):
        """Run one inference on any available worker (blocking).

        Returns the model's raw result; re-raises any exception the worker
        hit while processing this request.
        """
        # XXX: Should I use a more lightweight data structure, say, a future?
        result_queue = Queue(maxsize=1)
        self._queue.put((args, kwargs, result_queue))
        success, payload = result_queue.get()
        if success:
            return payload
        else:
            raise payload

    def close(self):
        """Shut down the pool: one sentinel per worker, then join them all."""
        for _ in self._workers:
            self._queue.put(None)
        for worker in self._workers:
            worker.join()

    def _worker(self):
        # Build the model on the worker thread itself so each thread has a
        # private instance.
        model = self._model_factory()
        while True:
            item = self._queue.get()
            if item is None:  # sentinel from close(); exit the loop
                # NOTE: task_done() is intentionally not called for the
                # sentinel; harmless because queue.join() is never used.
                break
            args, kwargs, result_queue = item
            try:
                result = model.ocr(*args, **kwargs)
                result_queue.put((True, result))
            except Exception as e:
                # Hand the failure back to the caller instead of letting it
                # kill the worker; the thread stays alive for later requests.
                result_queue.put((False, e))
            finally:
                self._queue.task_done()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
def create_model(lang):
    """Build a CPU-only PaddleOCR pipeline with angle classification enabled
    for the given language code."""
    ocr_engine = PaddleOCR(lang=lang, use_angle_cls=True, use_gpu=False)
    return ocr_engine
69
 
70
 
71
  model_managers = {}
 
74
  model_managers[lang] = model_manager
75
 
76
 
77
def close_model_managers():
    """Tear down the worker-thread pool of every registered model manager."""
    for mgr in model_managers.values():
        mgr.close()


# XXX: Not sure if gradio allows adding custom teardown logic
atexit.register(close_model_managers)
84
+
85
+
86
def inference(img, lang):
    """Run OCR on the image at path *img* and return an annotated image.

    Dispatches to the worker pool for *lang*, then draws the detected
    boxes, recognized texts, and confidence scores onto the source image.
    """
    manager = model_managers[lang]
    # NOTE(review): assumes at least one line was detected; the first result
    # entry may need a None/empty guard for blank images — confirm against
    # PaddleOCR's contract.
    lines = manager.infer(img, cls=True)[0]
    source = Image.open(img).convert("RGB")
    boxes, texts, scores = [], [], []
    for entry in lines:
        boxes.append(entry[0])
        texts.append(entry[1][0])
        scores.append(entry[1][1])
    return draw_ocr(source, boxes, texts, scores,
                    font_path="./simfang.ttf")
97
 
98
 
99
  title = 'PaddleOCR'