sayed99 committed on
Commit
15f4de2
·
1 Parent(s): 5ab5f5e

image caption, image generation tool added

Browse files
Files changed (2) hide show
  1. tools/image_caption.py +23 -0
  2. tools/image_generation.py +52 -0
tools/image_caption.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client, handle_file
2
+ from typing import Any, Optional
3
+ from smolagents.tools import Tool
4
+
5
+
6
class ImageCaptionTool(Tool):
    """Tool that asks a remote captioning service to describe an image.

    The caption is requested from the ``hysts/image-captioning-with-blip``
    endpoint through ``gradio_client`` on every call.
    """

    name = "image_caption"
    description = "Provides a caption for the given image."
    inputs = {
        'image_path': {
            'type': 'any',
            'description': 'The image path for which to generate a caption.',
        }
    }
    output_type = "string"

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used here.
        self.is_initialized = False

    def forward(self, image_path: Any) -> Any:
        """Return the caption produced for ``image_path``.

        Args:
            image_path: Location of the image; it is wrapped with
                ``handle_file`` before being sent to the endpoint.

        Returns:
            Whatever the ``/caption`` endpoint returns (declared as a
            string by ``output_type``).
        """
        # A new client is created per call; no connection is cached.
        captioner = Client("hysts/image-captioning-with-blip")
        uploaded_image = handle_file(image_path)
        # "A picture of" is passed as the text input alongside the image.
        return captioner.predict(
            image=uploaded_image,
            text="A picture of",
            api_name="/caption",
        )
tools/image_generation.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from gradio_client import Client
2
+ from typing import Any
3
+ from smolagents.tools import Tool
4
+ import os
5
+ import shutil
6
+ from pathlib import Path
7
+ import uuid
8
+ from PIL import Image
9
+
10
+
11
class ImageGenerationTool(Tool):
    """Tool that generates an image from a text prompt and stores a local copy.

    The image is requested from the ``mukaist/Midjourney`` endpoint through
    ``gradio_client``; the returned file is copied into a ``generations``
    directory under the current working directory.
    """

    name = "image_generation"
    description = """
    Generates an image based on the given prompt and saves it locally.

    Args:
        prompt (str): The prompt for image generation.

    Returns:
        tuple: A tuple containing:
            - Image.Image: The generated image.
            - pathlib.Path: The file path where the image was saved.
    """
    inputs = {'prompt': {'type': 'string',
                         'description': 'The prompt for image generation.'}}
    output_type = "any"

    def forward(self, prompt: str) -> Any:
        """Generate an image for ``prompt`` and save it under ``generations/``.

        Args:
            prompt: Text describing the image to generate.

        Returns:
            A ``(image, save_path)`` tuple: the PIL image opened from the
            saved copy, and the ``pathlib.Path`` it was saved to.
        """
        client = Client("mukaist/Midjourney")
        # NOTE(review): `style` names a 2560 x 1440 preset while width/height
        # are 1024 - confirm against the endpoint's API which one wins.
        result = client.predict(
            prompt=prompt,
            negative_prompt="(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck",
            use_negative_prompt=True,
            style="2560 x 1440",
            seed=0,
            width=1024,
            height=1024,
            guidance_scale=6,
            randomize_seed=True,
            api_name="/run"
        )

        # Presumably the first entry of the returned gallery; its 'image'
        # value is used as the source file path - verify against the API.
        image_path = result[0][0]['image']

        # Create the output directory first: shutil.copy does not create
        # missing parent directories and would raise FileNotFoundError.
        save_dir = Path.cwd() / "generations"
        save_dir.mkdir(parents=True, exist_ok=True)

        # A random hex suffix keeps successive generations from colliding.
        save_path = save_dir / f"generated_image_{uuid.uuid4().hex}.png"
        shutil.copy(image_path, save_path)

        print(f"Image saved at: {save_path}")
        return Image.open(save_path), save_path

    def __init__(self, *args, **kwargs):
        # Positional and keyword arguments are accepted but not used here.
        self.is_initialized = False