summaryrefslogtreecommitdiffstats
path: root/g4f
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--g4f/__init__.py16
1 files changed, 9 insertions, 7 deletions
diff --git a/g4f/__init__.py b/g4f/__init__.py
index faef7923..2c9ef7d7 100644
--- a/g4f/__init__.py
+++ b/g4f/__init__.py
@@ -1,8 +1,8 @@
from __future__ import annotations
from requests import get
from .models import Model, ModelUtils, _all_models
-from .Provider import BaseProvider, RetryProvider
-from .typing import Messages, CreateResult, Union, List
+from .Provider import BaseProvider, AsyncGeneratorProvider, RetryProvider
+from .typing import Messages, CreateResult, AsyncResult, Union, List
from . import debug
version = '0.1.8.7'
@@ -80,13 +80,15 @@ class ChatCompletion:
messages : Messages,
provider : Union[type[BaseProvider], None] = None,
stream : bool = False,
- ignored : List[str] = None, **kwargs) -> str:
-
- if stream:
- raise ValueError('"create_async" does not support "stream" argument')
-
+ ignored : List[str] = None,
+ **kwargs) -> Union[AsyncResult, str]:
model, provider = get_model_and_provider(model, provider, False, ignored)
+ if stream:
+ if isinstance(provider, type) and issubclass(provider, AsyncGeneratorProvider):
+ return await provider.create_async_generator(model.name, messages, **kwargs)
+ raise ValueError(f'{provider.__name__} does not support "stream" argument')
+
return await provider.create_async(model.name, messages, **kwargs)
class Completion: