slotab commited on
Commit
10bdd43
·
1 Parent(s): d3b8e67

AutoImageProcessor mobilenet

Browse files
Files changed (2) hide show
  1. app.py +30 -22
  2. requirements.txt +48 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import requests
3
  from PIL import Image
4
  import torch
5
  import numpy
 
6
 
7
  from transformers import DetrImageProcessor, DetrForSegmentation, AutoImageProcessor, AutoModelForImageClassification
8
  from transformers.models.detr.feature_extraction_detr import rgb_to_id
@@ -10,29 +11,36 @@ from transformers.models.detr.feature_extraction_detr import rgb_to_id
10
  url = "http://images.cocodataset.org/val2017/000000039769.jpg"
11
  image = Image.open(requests.get(url, stream=True).raw)
12
 
13
- feature_extractor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
14
- model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
- # prepare image for the model
17
- inputs = feature_extractor(images=image, return_tensors="pt")
18
-
19
- # forward pass
20
  outputs = model(**inputs)
 
21
 
22
- # use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format
23
- processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
24
- result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
25
-
26
- # the segmentation is stored in a special-format png
27
- panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
28
- panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
29
- # retrieve the ids corresponding to each mask
30
- panoptic_seg_id = rgb_to_id(panoptic_seg)
31
-
32
-
33
-
34
-
35
- # preprocessor = AutoImageProcessor.from_pretrained("google/mobilenet_v2_1.0_224")
36
- # model = AutoModelForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
37
 
38
- # inputs = preprocessor(images=image, return_tensors="pt")
 
3
  from PIL import Image
4
  import torch
5
  import numpy
6
+ import gradio as gr
7
 
8
  from transformers import DetrImageProcessor, DetrForSegmentation, AutoImageProcessor, AutoModelForImageClassification
9
  from transformers.models.detr.feature_extraction_detr import rgb_to_id
 
11
  url = "http://images.cocodataset.org/val2017/000000039769.jpg"
12
  image = Image.open(requests.get(url, stream=True).raw)
13
 
14
+ # feature_extractor = DetrImageProcessor.from_pretrained("facebook/post_process_panoptic_segmentation")
15
+ # model = DetrForSegmentation.from_pretrained("facebook/post_process_panoptic_segmentation")
16
+ #
17
+ # # prepare image for the model
18
+ # inputs = feature_extractor(images=image, return_tensors="pt")
19
+ #
20
+ # # forward pass
21
+ # outputs = model(**inputs)
22
+ #
23
+ # # use the `post_process_panoptic` method of `DetrFeatureExtractor` to convert to COCO format
24
+ # processed_sizes = torch.as_tensor(inputs["pixel_values"].shape[-2:]).unsqueeze(0)
25
+ # result = feature_extractor.post_process_panoptic(outputs, processed_sizes)[0]
26
+ #
27
+ # # the segmentation is stored in a special-format png
28
+ # panoptic_seg = Image.open(io.BytesIO(result["png_string"]))
29
+ # panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
30
+ # # retrieve the ids corresponding to each mask
31
+ # panoptic_seg_id = rgb_to_id(panoptic_seg)
32
+
33
+
34
+ preprocessor = AutoImageProcessor.from_pretrained("google/mobilenet_v2_1.0_224")
35
+ model = AutoModelForImageClassification.from_pretrained("google/mobilenet_v2_1.0_224")
36
+
37
+ inputs = preprocessor(images=image, return_tensors="pt")
38
 
 
 
 
 
39
  outputs = model(**inputs)
40
+ logits = outputs.logits
41
 
42
+ # model predicts one of the 1000 ImageNet classes
43
+ predicted_class_idx = logits.argmax(-1).item()
44
+ print("Predicted class:", model.config.id2label[predicted_class_idx])
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # gr.Image(image).launch()
requirements.txt CHANGED
@@ -1,26 +1,74 @@
 
 
 
1
  certifi==2024.7.4
2
  charset-normalizer==3.3.2
 
 
 
 
 
 
 
 
3
  filelock==3.15.4
 
4
  fsspec==2024.6.1
 
 
 
 
 
 
5
  huggingface-hub==0.24.2
6
  idna==3.7
 
7
  Jinja2==3.1.4
 
 
8
  MarkupSafe==2.1.5
 
 
9
  mpmath==1.3.0
10
  networkx==3.3
11
  numpy==2.0.1
 
12
  packaging==24.1
 
13
  pillow==10.4.0
 
 
 
 
 
 
 
 
 
14
  PyYAML==6.0.1
15
  regex==2024.7.24
16
  requests==2.32.3
 
 
17
  safetensors==0.4.3
 
 
 
 
 
18
  sympy==1.13.1
19
  timm==1.0.7
20
  tokenizers==0.19.1
 
21
  torch==2.4.0
22
  torchvision==0.19.0
23
  tqdm==4.66.4
24
  transformers==4.43.3
 
25
  typing_extensions==4.12.2
 
26
  urllib3==2.2.2
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.4.0
4
  certifi==2024.7.4
5
  charset-normalizer==3.3.2
6
+ click==8.1.7
7
+ contourpy==1.2.1
8
+ cycler==0.12.1
9
+ dnspython==2.6.1
10
+ email_validator==2.2.0
11
+ fastapi==0.111.1
12
+ fastapi-cli==0.0.4
13
+ ffmpy==0.3.2
14
  filelock==3.15.4
15
+ fonttools==4.53.1
16
  fsspec==2024.6.1
17
+ gradio==4.39.0
18
+ gradio_client==1.1.1
19
+ h11==0.14.0
20
+ httpcore==1.0.5
21
+ httptools==0.6.1
22
+ httpx==0.27.0
23
  huggingface-hub==0.24.2
24
  idna==3.7
25
+ importlib_resources==6.4.0
26
  Jinja2==3.1.4
27
+ kiwisolver==1.4.5
28
+ markdown-it-py==3.0.0
29
  MarkupSafe==2.1.5
30
+ matplotlib==3.9.1
31
+ mdurl==0.1.2
32
  mpmath==1.3.0
33
  networkx==3.3
34
  numpy==2.0.1
35
+ orjson==3.10.6
36
  packaging==24.1
37
+ pandas==2.2.2
38
  pillow==10.4.0
39
+ pydantic==2.8.2
40
+ pydantic_core==2.20.1
41
+ pydub==0.25.1
42
+ Pygments==2.18.0
43
+ pyparsing==3.1.2
44
+ python-dateutil==2.9.0.post0
45
+ python-dotenv==1.0.1
46
+ python-multipart==0.0.9
47
+ pytz==2024.1
48
  PyYAML==6.0.1
49
  regex==2024.7.24
50
  requests==2.32.3
51
+ rich==13.7.1
52
+ ruff==0.5.5
53
  safetensors==0.4.3
54
+ semantic-version==2.10.0
55
+ shellingham==1.5.4
56
+ six==1.16.0
57
+ sniffio==1.3.1
58
+ starlette==0.37.2
59
  sympy==1.13.1
60
  timm==1.0.7
61
  tokenizers==0.19.1
62
+ tomlkit==0.12.0
63
  torch==2.4.0
64
  torchvision==0.19.0
65
  tqdm==4.66.4
66
  transformers==4.43.3
67
+ typer==0.12.3
68
  typing_extensions==4.12.2
69
+ tzdata==2024.1
70
  urllib3==2.2.2
71
+ uvicorn==0.30.3
72
+ uvloop==0.19.0
73
+ watchfiles==0.22.0
74
+ websockets==11.0.3