Skip to content

Commit

Permalink
Handle Gemini output types solely in GeminiTypeAdapter
Browse files Browse the repository at this point in the history
We are currently using intermediate `Json`, `Choice` and `List` classes
which, in addition to being cumbersome, are not useful abstractions.
This commit bypasses these abstractions to handle the output types
directly at the level of `GeminiTypeAdapter`.
  • Loading branch information
rlouf committed Feb 23, 2025
1 parent 1af6d3f commit f6e021c
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
65 changes: 41 additions & 24 deletions outlines/models/gemini.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
"""Integration with Gemini's API."""

from enum import EnumMeta
from functools import singledispatchmethod
from types import NoneType
from typing import Optional, Union
from typing import Optional, Union, Any, GenericAlias

from pydantic import BaseModel
from typing_extensions import _TypedDictMeta # type: ignore

from outlines.models.base import Model, ModelTypeAdapter
from outlines.prompts import Vision
from outlines.types import Choice, Json, List

__all__ = ["Gemini"]

Expand Down Expand Up @@ -55,40 +55,57 @@ def format_vision_input(self, model_input: Vision):

@singledispatchmethod
def format_output_type(self, output_type):
raise NotImplementedError
match output_type:
case GenericAlias():
return self.format_list_output_type(output_type)
case _:
raise TypeError(
f"The type {output_type} is not supported by the Gemini API. "
"You can use a local model or dottxt instead."
)

@format_output_type.register(List)
def format_list_output_type(self, output_type):
return {
"response_mime_type": "application/json",
"response_schema": list[output_type.definition.definition],
}

@format_output_type.register(NoneType)
def format_none_output_type(self, output_type):
return {}

@format_output_type.register(Json)
def format_json_output_type(self, output_type):
"""Gemini only accepts Pydantic models and TypeDicts to define the JSON structure."""
if issubclass(output_type.definition, BaseModel):
args = output_type.__args__
if len(args) == 1 and issubclass(args[0], BaseModel):
return {
"response_mime_type": "application/json",
"response_schema": output_type.definition,
"response_schema": output_type,
}
elif isinstance(output_type.definition, _TypedDictMeta):
elif len(args) == 1 and isinstance(args[0], _TypedDictMeta):
return {
"response_mime_type": "application/json",
"response_schema": output_type.definition,
"response_schema": output_type,
}
else:
raise NotImplementedError
raise TypeError(
"Gemini models only support lists of pydantic `BaseModel`s or `TypedDict`s. "
f"You passed {output_type} instead. Chances are the output type you are "
"trying to define is currently only supported by local models or dottxt."
)

@format_output_type.register(NoneType)
def format_none_output_type(self, output_type):
return {}

@format_output_type.register(_TypedDictMeta)
def format_json_typeddict_type(self, output_type):
return {
"response_mime_type": "application/json",
"response_schema": output_type,
}

@format_output_type.register(type(BaseModel))
def format_json_pydantic_type(self, output_type):
return {
"response_mime_type": "application/json",
"response_schema": output_type,
}

@format_output_type.register(Choice)
@format_output_type.register(EnumMeta)
def format_enum_output_type(self, output_type):
return {
"response_mime_type": "text/x.enum",
"response_schema": output_type.definition,
"response_schema": output_type,
}


Expand All @@ -103,7 +120,7 @@ def __init__(self, model_name: str, *args, **kwargs):
def generate(
self,
model_input: Union[str, Vision],
output_type: Optional[Union[Json, EnumMeta]] = None,
output_type: Optional[Any] = None,
**inference_kwargs,
):
import google.generativeai as genai
Expand Down
21 changes: 12 additions & 9 deletions tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from outlines.models.gemini import Gemini
from outlines.prompts import Vision
from outlines.types import Choice, Json, List
from outlines.types import List

MODEL_NAME = "gemini-1.5-flash-latest"

Expand Down Expand Up @@ -60,7 +60,7 @@ def test_gemini_simple_pydantic():
class Foo(BaseModel):
bar: int

result = model.generate("foo?", Json(Foo))
result = model.generate("foo?", Foo)
assert isinstance(result, str)
assert "bar" in json.loads(result)

Expand Down Expand Up @@ -95,7 +95,7 @@ class Foo(BaseModel):
sna: int
bar: Bar

result = model.generate("foo?", Json(Foo))
result = model.generate("foo?", Foo)
assert isinstance(result, str)
assert "sna" in json.loads(result)
assert "bar" in json.loads(result)
Expand All @@ -115,7 +115,7 @@ def test_gemini_simple_json_schema_dict():
"title": "Foo",
"type": "object",
}
result = model.generate("foo?", Json(schema))
result = model.generate("foo?", schema)
assert isinstance(result, str)
assert "bar" in json.loads(result)

Expand All @@ -128,7 +128,7 @@ def test_gemini_simple_json_schema_string():
model = Gemini(MODEL_NAME)

schema = "{'properties': {'bar': {'title': 'Bar', 'type': 'integer'}}, 'required': ['bar'], 'title': 'Foo', 'type': 'object'}"
result = model.generate("foo?", Json(schema))
result = model.generate("foo?", schema)
assert isinstance(result, str)
assert "bar" in json.loads(result)

Expand All @@ -140,7 +140,7 @@ def test_gemini_simple_typed_dict():
class Foo(TypedDict):
bar: int

result = model.generate("foo?", Json(Foo))
result = model.generate("foo?", Foo)
assert isinstance(result, str)
assert "bar" in json.loads(result)

Expand All @@ -153,17 +153,20 @@ class Foo(Enum):
bar = "Bar"
foor = "Foo"

result = model.generate("foo?", Choice(Foo))
result = model.generate("foo?", Foo)
assert isinstance(result, str)
assert result == "Foo" or result == "Bar"


@pytest.mark.xfail(
reason="Gemini supports lists for choices but we do not as it is semantically incorrect."
)
@pytest.mark.api_call
def test_gemini_simple_choice_list():
model = Gemini(MODEL_NAME)

choices = ["Foo", "Bar"]
result = model.generate("foo?", Choice(choices))
result = model.generate("foo?", choices)
assert isinstance(result, str)
assert result == "Foo" or result == "Bar"

Expand All @@ -175,7 +178,7 @@ def test_gemini_simple_list_pydantic():
class Foo(BaseModel):
bar: int

result = model.generate("foo?", List(Json(Foo)))
result = model.generate("foo?", list[Foo])
assert isinstance(json.loads(result), list)
assert isinstance(json.loads(result)[0], dict)
assert "bar" in json.loads(result)[0]

0 comments on commit f6e021c

Please sign in to comment.