summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/image.py18
1 files changed, 12 insertions, 6 deletions
diff --git a/g4f/image.py b/g4f/image.py
index 61081ea1..94b8c24c 100644
--- a/g4f/image.py
+++ b/g4f/image.py
@@ -20,23 +20,23 @@ def to_image(image: ImageType, is_svg: bool = False) -> Image.Image:
try:
import cairosvg
except ImportError:
- raise RuntimeError('Install "cairosvg" package for open svg images')
+ raise RuntimeError('Install "cairosvg" package for svg images')
if not isinstance(image, bytes):
image = image.read()
buffer = BytesIO()
cairosvg.svg2png(image, write_to=buffer)
- image = Image.open(buffer)
+ return Image.open(buffer)
if isinstance(image, str):
is_data_uri_an_image(image)
image = extract_data_uri(image)
if isinstance(image, bytes):
is_accepted_format(image)
- image = Image.open(BytesIO(image))
+ return Image.open(BytesIO(image))
elif not isinstance(image, Image.Image):
image = Image.open(image)
copy = image.copy()
copy.format = image.format
- image = copy
+ return copy
return image
def is_allowed_extension(filename: str) -> bool:
@@ -138,6 +138,7 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
Returns:
Image.Image: The processed image.
"""
+ # Fix orientation
orientation = get_orientation(img)
if orientation:
if orientation > 4:
@@ -148,7 +149,14 @@ def process_image(img: Image.Image, new_width: int, new_height: int) -> Image.Im
img = img.transpose(Image.ROTATE_270)
if orientation in [7, 8]:
img = img.transpose(Image.ROTATE_90)
+ # Resize image
img.thumbnail((new_width, new_height))
+ # Remove transparency
+ if img.mode != "RGB":
+ img.load()
+ white = Image.new('RGB', img.size, (255, 255, 255))
+ white.paste(img, mask=img.split()[3])
+ return white
return img
def to_base64(image: Image.Image, compression_rate: float) -> str:
@@ -163,8 +171,6 @@ def to_base64(image: Image.Image, compression_rate: float) -> str:
str: The base64-encoded image.
"""
output_buffer = BytesIO()
- if image.mode != "RGB":
- image = image.convert('RGB')
image.save(output_buffer, format="JPEG", quality=int(compression_rate * 100))
return base64.b64encode(output_buffer.getvalue()).decode()