diff --git a/senju/image_reco.py b/senju/image_reco.py
index e3a207c..66ddafe 100644
--- a/senju/image_reco.py
+++ b/senju/image_reco.py
@@ -7,10 +7,12 @@ from transformers import BlipProcessor, BlipForConditionalGeneration
 class ImageDescriptionGenerator:
     def __init__(self, model_name="Salesforce/blip-image-captioning-base"):
         """
-        Initialize an image description generator using a vision-language model.
+        Initialize an image description generator using a vision-language
+        model.
 
         Args:
-            model_name: The name of the model to use (default: BLIP captioning model)
+            model_name: The name of the model to use
+                (default: BLIP captioning model)
         """
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         print(f"Using device: {self.device}")
@@ -27,7 +29,8 @@ class ImageDescriptionGenerator:
             max_length: Maximum length of the generated caption
 
         Returns:
-            dict: A dictionary containing the generated description and confidence score
+            dict: A dictionary containing the generated description
+                and confidence score
         """
         # Convert uploaded bytes to image
         img = Image.open(io.BytesIO(image_data)).convert("RGB")
@@ -52,7 +55,7 @@ class ImageDescriptionGenerator:
 
         return {
             "description": caption,
-            "confidence": None  # Most caption models don't provide confidence scores
+            "confidence": None
         }
 
 
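
For reviewers, a minimal usage sketch of the class this patch touches. It is an illustration only: the constructor default and the {"description", "confidence"} return shape are visible in the hunks above, but the method name generate_description, the sample file path, and the max_length value are assumptions, since the hunks do not show the method's signature itself.

# Hedged usage sketch; not part of the patch.
# Assumptions: the public method is named generate_description, and
# "example.jpg" exists locally. Everything else mirrors the diff above.
from senju.image_reco import ImageDescriptionGenerator

# Constructor default comes from the diff: "Salesforce/blip-image-captioning-base"
generator = ImageDescriptionGenerator()

# Read raw bytes, matching the "Convert uploaded bytes to image" step in the diff
with open("example.jpg", "rb") as f:  # hypothetical input file
    image_data = f.read()

# max_length caps the caption length, per the docstring in the second hunk
result = generator.generate_description(image_data, max_length=50)  # assumed method name

print(result["description"])  # the generated caption
print(result["confidence"])   # None: this model does not report a confidence score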