diff --git a/pyproject.toml b/pyproject.toml
index 5a944ea..a10ccd8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,8 @@ dependencies = [
     "coverage (>=7.6.12,<8.0.0)",
     "pytest-httpserver (>=1.1.2,<2.0.0)",
     "pillow (>=11.1.0,<12.0.0)",
-    "tensorflow (>=2.19.0,<3.0.0)",
+    "torch (>=2.6.0,<3.0.0)",
+    "transformers (>=4.50.0,<5.0.0)",
 ]
 
 
diff --git a/senju/image_reco.py b/senju/image_reco.py
index f8e82fd..e3a207c 100644
--- a/senju/image_reco.py
+++ b/senju/image_reco.py
@@ -1,40 +1,63 @@
-import numpy as np
+import torch
 from PIL import Image
-import keras
 import io
-
-g_model = None
+from transformers import BlipProcessor, BlipForConditionalGeneration
 
 
-class SimpleClassifier:
-    def __init__(self):
-        global g_model
-        if g_model is None:
-            g_model = keras.applications.MobileNetV2(weights="imagenet")
-        self.model = g_model
+class ImageDescriptionGenerator:
+    def __init__(self, model_name="Salesforce/blip-image-captioning-base"):
+        """
+        Initialize an image description generator using a vision-language model.
 
-    def classify(self, image_data):
+        Args:
+            model_name: The name of the model to use (default: BLIP captioning model)
+        """
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        print(f"Using device: {self.device}")
+
+        self.processor = BlipProcessor.from_pretrained(model_name)
+        self.model = BlipForConditionalGeneration.from_pretrained(model_name).to(self.device)
+
+    def generate_description(self, image_data, max_length=50):
+        """
+        Generate a descriptive caption for the given image.
+
+        Args:
+            image_data: Raw image data (bytes)
+            max_length: Maximum length of the generated caption
+
+        Returns:
+            dict: A dictionary containing the generated description and a confidence placeholder
+        """
         # Convert uploaded bytes to image
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
-        img = img.resize((224, 224))
-        img_array = np.array(img) / 255.0
-        img_array = np.expand_dims(img_array, axis=0)
 
-        # Get predictions
-        predictions = self.model.predict(img_array)
-        results = keras.applications.mobilenet_v2.decode_predictions(
-            predictions, top=5)[0]
+        # Process the image
+        inputs = self.processor(
+            images=img, return_tensors="pt").to(self.device)
 
-        data: dict = {}
-        all_labels: list[dict] = []
-        data["best_guess"] = {"label": "", "confidence": float(0)}
-        for _, label, score in results:
-            score = float(score)
-            datapoint = {"label": label, "confidence": score}
-            all_labels.append(datapoint)
-            if data["best_guess"]["confidence"] < score:
-                data["best_guess"] = datapoint
+        # Generate caption
+        with torch.no_grad():
+            output = self.model.generate(
+                **inputs,
+                max_length=max_length,
+                num_beams=5,
+                num_return_sequences=1,
+                temperature=1.0,
+                do_sample=False
+            )
 
-        data["all"] = all_labels
+        # Decode the caption
+        caption = self.processor.decode(output[0], skip_special_tokens=True)
 
-        return data
+        return {
+            "description": caption,
+            "confidence": None  # Most caption models don't provide confidence scores
+        }
+
+
+g_descriptor: ImageDescriptionGenerator = ImageDescriptionGenerator()
+
+
+def gen_response(image_data) -> dict:
+    return g_descriptor.generate_description(image_data)
diff --git a/senju/main.py b/senju/main.py
index 63154d1..4390363 100644
--- a/senju/main.py
+++ b/senju/main.py
@@ -6,7 +6,7 @@ from flask import (Flask, redirect, render_template, request, url_for,
                    send_from_directory)
 
 from senju.haiku import Haiku
-from senju.image_reco import SimpleClassifier
+from senju.image_reco import gen_response
 from senju.store_manager import StoreManager
 
 import os
@@ -64,7 +64,6 @@ def scan_view():
 @app.route("/api/v1/image_reco", methods=['POST'])
 def image_recognition():
-    # note that the classifier is a singleton
-    classifier = SimpleClassifier()
+    # note that the caption generator is a module-level singleton
     if 'image' not in request.files:
         return "No image file provided", 400
 
@@ -72,7 +71,7 @@ def image_recognition():
     image_data = image_file.read()
 
     try:
-        results = classifier.classify(image_data)
+        results = gen_response(image_data)
         return results
     except Exception as e:
         return str(e), 500
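
Reviewer note: a quick way to exercise this change end to end is the smoke test below. It is a sketch, not part of the patch: it assumes the senju Flask app is running on localhost:5000, that the requests package is available (it is not a dependency added by this diff), and that a placeholder image test.jpg exists in the working directory. The multipart field name "image" matches the handler above, and Flask serializes the returned dict as JSON.

# Hypothetical smoke test (not part of this patch): POST a local image to the
# new captioning endpoint and print the response fields.
import requests  # assumed available; not added by this diff

with open("test.jpg", "rb") as f:  # test.jpg is a placeholder image
    resp = requests.post(
        "http://localhost:5000/api/v1/image_reco",
        files={"image": ("test.jpg", f, "image/jpeg")},
    )

resp.raise_for_status()
data = resp.json()
print(data["description"])  # e.g. a short BLIP caption of the image
print(data["confidence"])   # always None: the endpoint returns no score

Calling senju.image_reco.gen_response(image_bytes) directly exercises the same path without the HTTP layer; note that importing the module eagerly downloads and loads the BLIP weights, since g_descriptor is instantiated at import time.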