Mirror of https://github.com/senju1337/senju.git, synced 2025-12-24 07:39:29 +00:00
feat: generate a good description of an image
Refs: OPS-85
parent 47d4d9e4b9
commit 665f69987b
3 changed files with 56 additions and 33 deletions
@@ -17,7 +17,8 @@ dependencies = [
     "coverage (>=7.6.12,<8.0.0)",
     "pytest-httpserver (>=1.1.2,<2.0.0)",
     "pillow (>=11.1.0,<12.0.0)",
-    "tensorflow (>=2.19.0,<3.0.0)",
+    "torch (>=2.6.0,<3.0.0)",
+    "transformers (>=4.50.0,<5.0.0)",
 ]
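After installing the updated dependency set, a quick import check can confirm the swap from tensorflow to torch/transformers; this sketch is illustrative and not part of the commit:

# Sanity check for the swapped runtime dependencies.
import torch
import transformers
import PIL

print("torch", torch.__version__, "CUDA available:", torch.cuda.is_available())
print("transformers", transformers.__version__)
print("pillow", PIL.__version__)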
@@ -1,40 +1,63 @@
-import numpy as np
+import torch
 from PIL import Image
-import keras
 import io
+from transformers import BlipProcessor, BlipForConditionalGeneration
 
-g_model = None
 
-
-class SimpleClassifier:
-    def __init__(self):
-        global g_model
-        if g_model is None:
-            g_model = keras.applications.MobileNetV2(weights="imagenet")
-        self.model = g_model
-
-    def classify(self, image_data):
+class ImageDescriptionGenerator:
+    def __init__(self, model_name="Salesforce/blip-image-captioning-base"):
+        """
+        Initialize an image description generator using a vision-language model.
+
+        Args:
+            model_name: The name of the model to use (default: BLIP captioning model)
+        """
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {self.device}")
+
+        self.processor = BlipProcessor.from_pretrained(model_name)
+        self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)  # keep model on the same device as the inputs
+
+    def generate_description(self, image_data, max_length=50):
+        """
+        Generate a descriptive caption for the given image.
+
+        Args:
+            image_data: Raw image data (bytes)
+            max_length: Maximum length of the generated caption
+
+        Returns:
+            dict: A dictionary containing the generated description and confidence score
+        """
         # Convert uploaded bytes to image
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
-        img = img.resize((224, 224))
-        img_array = np.array(img) / 255.0
-        img_array = np.expand_dims(img_array, axis=0)
 
-        # Get predictions
-        predictions = self.model.predict(img_array)
-        results = keras.applications.mobilenet_v2.decode_predictions(
-            predictions, top=5)[0]
+        # Process the image
+        inputs = self.processor(
+            images=img, return_tensors="pt").to(self.device)
 
-        data: dict = {}
-        all_labels: list[dict] = []
-        data["best_guess"] = {"label": "", "confidence": float(0)}
-        for _, label, score in results:
-            score = float(score)
-            datapoint = {"label": label, "confidence": score}
-            all_labels.append(datapoint)
-            if data["best_guess"]["confidence"] < score:
-                data["best_guess"] = datapoint
-
-        data["all"] = all_labels
-
-        return data
+        # Generate caption
+        with torch.no_grad():
+            output = self.model.generate(
+                **inputs,
+                max_length=max_length,
+                num_beams=5,
+                num_return_sequences=1,
+                temperature=1.0,
+                do_sample=False
+            )
+
+        # Decode the caption
+        caption = self.processor.decode(output[0], skip_special_tokens=True)
+
+        return {
+            "description": caption,
+            "confidence": None  # Most caption models don't provide confidence scores
+        }
+
+
+g_descriptor: ImageDescriptionGenerator = ImageDescriptionGenerator()
+
+
+def gen_response(image_data) -> dict:
+    return g_descriptor.generate_description(image_data)
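The rewritten senju.image_reco module can be exercised directly, outside Flask; a minimal sketch, assuming a local file named example.jpg (placeholder) and noting that the first call downloads the BLIP weights from the Hugging Face hub:

# Caption a local image via the new module-level helper.
from senju.image_reco import gen_response

with open("example.jpg", "rb") as f:   # placeholder image path
    image_bytes = f.read()

result = gen_response(image_bytes)
print(result["description"])   # generated caption text
print(result["confidence"])    # None: the captioning model returns no score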
@@ -6,7 +6,7 @@ from flask import (Flask, redirect, render_template, request, url_for,
                    send_from_directory)
 
 from senju.haiku import Haiku
-from senju.image_reco import SimpleClassifier
+from senju.image_reco import gen_response
 from senju.store_manager import StoreManager
 
 import os
@@ -64,7 +64,6 @@ def scan_view():
 @app.route("/api/v1/image_reco", methods=['POST'])
 def image_recognition():
     # note that the classifier is a singleton
-    classifier = SimpleClassifier()
     if 'image' not in request.files:
         return "No image file provided", 400
 
@@ -72,7 +71,7 @@ def image_recognition():
     image_data = image_file.read()
 
     try:
-        results = classifier.classify(image_data)
+        results = gen_response(image_data)
         return results
    except Exception as e:
         return str(e), 500
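For reference, a client-side call against the unchanged /api/v1/image_reco route; the localhost:5000 address and the requests library are assumptions for illustration, not something this commit adds:

# Post an image to the endpoint and print the JSON response.
import requests

with open("example.jpg", "rb") as f:   # placeholder image path
    resp = requests.post(
        "http://localhost:5000/api/v1/image_reco",
        files={"image": f},            # the view rejects requests without an 'image' file
    )

print(resp.status_code)   # 400 if no image was sent, 500 on processing errors
print(resp.json())        # {"description": "...", "confidence": null}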