brendenc commited on
Commit
479a349
·
1 Parent(s): 06ce617

Updated from colab

Browse files
Files changed (1) hide show
  1. app.py +1 -7
app.py CHANGED
@@ -8,17 +8,11 @@ import matplotlib.pyplot as plt
8
  extractor = AutoFeatureExtractor.from_pretrained("brendenc/my-segmentation-model")
9
  model = AutoModelForImageClassification.from_pretrained("brendenc/my-segmentation-model")
10
 
11
- collapse_categories = {**{i: 0 for i in range(1, 8)},
12
- **{i: 1 for i in range(8, 10)},
13
- **{i: 2 for i in range(10, 18)},
14
- **{i: 3 for i in range(18, 28)}}
15
-
16
  def classify(im):
17
  inputs = extractor(images=im, return_tensors="pt")
18
  outputs = model(**inputs)
19
  logits = outputs.logits
20
  classes = logits[0].detach().numpy().argmax(axis=0)
21
- classes = np.vectorize(lambda x: collapse_categories.get(x, 4))(classes)
22
  colors = np.array([[128,0,0], [128,128,0], [0, 0, 128], [128,0,128], [0, 0, 0]])
23
  return colors[classes]
24
 
@@ -26,8 +20,8 @@ example_imgs = [f"example_{i}.jpg" for i in range(3)]
26
  interface = gr.Interface(classify,
27
  inputs="image",
28
  outputs="image",
29
- title = "Street Image Segmentation",
30
  examples = example_imgs,
 
31
  description = """Below is a simple app for image segmentation. This model was trained using""")
32
 
33
  interface.launch(debug=True)
 
8
  extractor = AutoFeatureExtractor.from_pretrained("brendenc/my-segmentation-model")
9
  model = AutoModelForImageClassification.from_pretrained("brendenc/my-segmentation-model")
10
 
 
 
 
 
 
11
  def classify(im):
12
  inputs = extractor(images=im, return_tensors="pt")
13
  outputs = model(**inputs)
14
  logits = outputs.logits
15
  classes = logits[0].detach().numpy().argmax(axis=0)
 
16
  colors = np.array([[128,0,0], [128,128,0], [0, 0, 128], [128,0,128], [0, 0, 0]])
17
  return colors[classes]
18
 
 
20
  interface = gr.Interface(classify,
21
  inputs="image",
22
  outputs="image",
 
23
  examples = example_imgs,
24
+ title = "Street Image Segmentation",
25
  description = """Below is a simple app for image segmentation. This model was trained using""")
26
 
27
  interface.launch(debug=True)