feat: generate a good description of an image

Refs: OPS-85
Christoph J. Scherr 2025-03-23 17:00:47 +01:00
parent 47d4d9e4b9
commit 665f69987b
GPG key ID: 9EB784BB202BB7BB
3 changed files with 56 additions and 33 deletions

pyproject.toml

@@ -17,7 +17,8 @@ dependencies = [
     "coverage (>=7.6.12,<8.0.0)",
     "pytest-httpserver (>=1.1.2,<2.0.0)",
     "pillow (>=11.1.0,<12.0.0)",
-    "tensorflow (>=2.19.0,<3.0.0)",
+    "torch (>=2.6.0,<3.0.0)",
+    "transformers (>=4.50.0,<5.0.0)",
 ]
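
A quick way to sanity-check the swapped dependency set after reinstalling (a sketch, not part of this commit):

# Sanity-check sketch: confirms the new torch/transformers stack imports
# cleanly and reports the device the captioning model would run on.
import torch
import transformers

print("torch", torch.__version__, "| transformers", transformers.__version__)
print("device:", "cuda" if torch.cuda.is_available() else "cpu")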

senju/image_reco.py

@@ -1,40 +1,63 @@
-import numpy as np
+import torch
 from PIL import Image
-import keras
 import io
+from transformers import BlipProcessor, BlipForConditionalGeneration

-g_model = None

-class SimpleClassifier:
-    def __init__(self):
-        global g_model
-        if g_model is None:
-            g_model = keras.applications.MobileNetV2(weights="imagenet")
-        self.model = g_model
-
-    def classify(self, image_data):
+class ImageDescriptionGenerator:
+    def __init__(self, model_name="Salesforce/blip-image-captioning-base"):
+        """
+        Initialize an image description generator using a vision-language model.
+
+        Args:
+            model_name: The name of the model to use (default: BLIP captioning model)
+        """
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {self.device}")
+        self.processor = BlipProcessor.from_pretrained(model_name)
+        # move the model to the selected device so it matches the inputs below
+        self.model = BlipForConditionalGeneration.from_pretrained(
+            model_name).to(self.device)
+
+    def generate_description(self, image_data, max_length=50):
+        """
+        Generate a descriptive caption for the given image.
+
+        Args:
+            image_data: Raw image data (bytes)
+            max_length: Maximum length of the generated caption
+
+        Returns:
+            dict: A dictionary containing the generated description and confidence score
+        """
         # Convert uploaded bytes to image
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
-        img = img.resize((224, 224))
-        img_array = np.array(img) / 255.0
-        img_array = np.expand_dims(img_array, axis=0)
-
-        # Get predictions
-        predictions = self.model.predict(img_array)
-        results = keras.applications.mobilenet_v2.decode_predictions(
-            predictions, top=5)[0]
+        # Process the image
+        inputs = self.processor(
+            images=img, return_tensors="pt").to(self.device)

-        data: dict = {}
-        all_labels: list[dict] = []
-        data["best_guess"] = {"label": "", "confidence": float(0)}
-        for _, label, score in results:
-            score = float(score)
-            datapoint = {"label": label, "confidence": score}
-            all_labels.append(datapoint)
-            if data["best_guess"]["confidence"] < score:
-                data["best_guess"] = datapoint
+        # Generate caption
+        with torch.no_grad():
+            output = self.model.generate(
+                **inputs,
+                max_length=max_length,
+                num_beams=5,
+                num_return_sequences=1,
+                temperature=1.0,
+                do_sample=False
+            )

-        data["all"] = all_labels
+        # Decode the caption
+        caption = self.processor.decode(output[0], skip_special_tokens=True)

-        return data
+        return {
+            "description": caption,
+            "confidence": None  # Most caption models don't provide confidence scores
+        }
+
+
+g_descriptor: ImageDescriptionGenerator = ImageDescriptionGenerator()
+
+
+def gen_response(image_data) -> dict:
+    return g_descriptor.generate_description(image_data)
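
For reference, a minimal sketch of driving the new module outside Flask (not part of this commit; "example.jpg" is a placeholder path):

# Usage sketch for the new senju.image_reco API from the diff above.
from senju.image_reco import gen_response

with open("example.jpg", "rb") as f:  # placeholder image path
    result = gen_response(f.read())

print(result["description"])

Note that g_descriptor is created at import time, so the first import of senju.image_reco downloads and loads the BLIP weights; later calls reuse the same instance.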

senju Flask app module

@@ -6,7 +6,7 @@ from flask import (Flask, redirect, render_template, request, url_for,
                    send_from_directory)
 from senju.haiku import Haiku
-from senju.image_reco import SimpleClassifier
+from senju.image_reco import gen_response
 from senju.store_manager import StoreManager
 import os
@@ -64,7 +64,6 @@ def scan_view():
 @app.route("/api/v1/image_reco", methods=['POST'])
 def image_recognition():
     # note that the classifier is a singleton
-    classifier = SimpleClassifier()

     if 'image' not in request.files:
         return "No image file provided", 400
@@ -72,7 +71,7 @@ def image_recognition():
     image_data = image_file.read()
     try:
-        results = classifier.classify(image_data)
+        results = gen_response(image_data)
         return results
     except Exception as e:
         return str(e), 500
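
An end-to-end check of the updated endpoint (a sketch; host, port, and image path are assumptions, while the route and the 'image' field name come from the diff above):

# Client sketch: POST an image to the captioning endpoint and print the reply.
import requests

with open("example.jpg", "rb") as f:  # placeholder image path
    resp = requests.post(
        "http://localhost:5000/api/v1/image_reco",  # assumed dev host/port
        files={"image": f},  # must match request.files['image'] in the view
    )

print(resp.status_code, resp.json())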