diff --git a/pyproject.toml b/pyproject.toml index 688e449..5a944ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,9 +2,12 @@ name = "senju" version = "0.1.0" description = "API / Webservice for Phrases/Words/Kanji/Haiku" -authors = [{ name = "Christoph J. Scherr", email = "software@cscherr.de" },{name = "Moritz Marquard", email="mrmarquard@protonmail.com"}] +authors = [ + { name = "Christoph J. Scherr", email = "software@cscherr.de" }, + { name = "Moritz Marquard", email = "mrmarquard@protonmail.com" }, +] readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.10,<3.13" dependencies = [ "jinja2 (>=3.1.5,<4.0.0)", "pytest>=7.0.0", @@ -13,6 +16,8 @@ dependencies = [ "requests (>=2.32.3,<3.0.0)", "coverage (>=7.6.12,<8.0.0)", "pytest-httpserver (>=1.1.2,<2.0.0)", + "pillow (>=11.1.0,<12.0.0)", + "tensorflow (>=2.19.0,<3.0.0)", ] diff --git a/senju/image_reco.py b/senju/image_reco.py new file mode 100644 index 0000000..f8e82fd --- /dev/null +++ b/senju/image_reco.py @@ -0,0 +1,40 @@ +import numpy as np +from PIL import Image +import keras +import io + +g_model = None + + +class SimpleClassifier: + def __init__(self): + global g_model + if g_model is None: + g_model = keras.applications.MobileNetV2(weights="imagenet") + self.model = g_model + + def classify(self, image_data): + # Convert uploaded bytes to image + img = Image.open(io.BytesIO(image_data)).convert("RGB") + img = img.resize((224, 224)) + img_array = np.array(img) / 255.0 + img_array = np.expand_dims(img_array, axis=0) + + # Get predictions + predictions = self.model.predict(img_array) + results = keras.applications.mobilenet_v2.decode_predictions( + predictions, top=5)[0] + + data: dict = {} + all_labels: list[dict] = [] + data["best_guess"] = {"label": "", "confidence": float(0)} + for _, label, score in results: + score = float(score) + datapoint = {"label": label, "confidence": score} + all_labels.append(datapoint) + if data["best_guess"]["confidence"] < score: + data["best_guess"] = datapoint + + data["all"] = all_labels + + return data diff --git a/senju/main.py b/senju/main.py index 4f16689..63154d1 100644 --- a/senju/main.py +++ b/senju/main.py @@ -6,6 +6,7 @@ from flask import (Flask, redirect, render_template, request, url_for, send_from_directory) from senju.haiku import Haiku +from senju.image_reco import SimpleClassifier from senju.store_manager import StoreManager import os @@ -60,6 +61,23 @@ def scan_view(): ) +@app.route("/api/v1/image_reco", methods=['POST']) +def image_recognition(): + # note that the classifier is a singleton + classifier = SimpleClassifier() + if 'image' not in request.files: + return "No image file provided", 400 + + image_file = request.files['image'] + image_data = image_file.read() + + try: + results = classifier.classify(image_data) + return results + except Exception as e: + return str(e), 500 + + @app.route("/api/v1/haiku", methods=['POST']) def generate_haiku(): if request.method == 'POST': diff --git a/senju/static/js/scan.js b/senju/static/js/scan.js index ad64ed6..8a0275a 100644 --- a/senju/static/js/scan.js +++ b/senju/static/js/scan.js @@ -56,15 +56,46 @@ function handleSubmit() { // Hide error errorMessage.classList.add("hidden"); - // Show response box + // Show loading state + document.getElementById("ai-response").textContent = "Analyzing image..."; responseBox.classList.remove("opacity-0"); - // Example response - document.getElementById("ai-response").textContent = - "Dominic Monaghan interviewing Elijah Wood if he will wear wigs"; + // Get the file from the input + const file = dropzoneFile.files[0]; + + // Create FormData object to send the file + const formData = new FormData(); + formData.append("image", file); + + // Send the image to your backend API + fetch("/api/v1/image_reco", { + method: "POST", + body: formData, + }) + .then((response) => { + if (!response.ok) { + throw new Error("Network response was not ok"); + } + return response.json(); + }) + .then((data) => { + // Extract top result and display it + if (data.results && data.results.length > 0) { + const topResult = data.results[0]; + document.getElementById("ai-response").textContent = + `${topResult.label} (${Math.round(topResult.confidence * 100)}% confidence)`; + } else { + document.getElementById("ai-response").textContent = + "Could not identify the image"; + } + }) + .catch((error) => { + console.error("Error:", error); + document.getElementById("ai-response").textContent = + "Error analyzing image"; + }); } else { errorMessage.classList.remove("hidden"); - uploadArea.classList.add("shake"); setTimeout(() => { uploadArea.classList.remove("shake");