summaryrefslogtreecommitdiffstats
path: root/g4f/Provider/Pi.py
blob: 9ecebafb9c1fae33bab6a02792324bd279dc8426 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations

from ..typing import CreateResult, Messages
from .base_provider import BaseProvider, format_prompt

import json
from cloudscraper import CloudScraper, session, create_scraper

class Pi(BaseProvider):
    """Provider for the Pi chat assistant served from https://pi.ai.

    All HTTP traffic goes through a ``cloudscraper`` session. If a Cloudflare
    challenge page ("Just a moment") is returned anyway, a ``RuntimeError`` is
    raised.
    """

    # Every endpoint used below lives on pi.ai.
    url             = "https://pi.ai/talk"
    working         = True
    supports_stream = True

    @classmethod
    def create_completion(
        cls,
        model: str,
        messages: Messages,
        stream: bool,
        proxy: str = None,
        scraper: CloudScraper = None,
        conversation: Conversation = None,
        **kwargs
    ) -> CreateResult:
        """Send ``messages`` to Pi and yield the answer piece by piece.

        Args:
            model: Unused by this provider.
            messages: Chat history; flattened into one prompt by ``format_prompt``.
            stream: Unused; the answer is always yielded incrementally.
            proxy: Optional proxy URL. Only applied when this method creates its
                own scraper (a caller-supplied ``scraper`` is not modified).
            scraper: Existing session to reuse instead of creating a new one.
            conversation: Existing conversation to continue instead of starting
                a new one.

        Yields:
            Text fragments of the answer not yet yielded.

        Raises:
            RuntimeError: If a Cloudflare challenge page is detected.
        """
        if not scraper:
            scraper = cls.get_scraper(proxy)
        if not conversation:
            conversation = cls.start_conversation(scraper)
        answer = cls.ask(scraper, messages, conversation)

        # Each "text" event is treated as the cumulative answer so far, so only
        # the part beyond what was already emitted is yielded.
        last_answer = 0
        for line in answer:
            if "text" in line:
                yield line["text"][last_answer:]
                last_answer = len(line["text"])

    @staticmethod
    def get_scraper(proxy: str = None) -> CloudScraper:
        """Create a cloudscraper session configured as desktop Chrome on Windows.

        Args:
            proxy: Optional proxy URL used for both HTTP and HTTPS requests.

        Returns:
            The configured ``CloudScraper`` session.
        """
        scraper = create_scraper(
            browser={
                'browser': 'chrome',
                'platform': 'windows',
                'desktop': True
            },
            sess=session()
        )
        # NOTE(review): this replaces the whole header mapping rather than
        # updating it, presumably discarding any default headers (e.g. a
        # User-Agent) the scraper was created with. Kept as-is; confirm before
        # switching to ``headers.update``.
        scraper.headers = {
            'Accept': '*/*',
            'Accept-Encoding': 'deflate,gzip,br',
        }
        if proxy:
            scraper.proxies = {'http': proxy, 'https': proxy}
        return scraper

    @staticmethod
    def start_conversation(scraper: CloudScraper) -> Conversation:
        """Start a new conversation and return its id plus the response cookies.

        Raises:
            RuntimeError: If a Cloudflare challenge page is returned.
        """
        response = scraper.post('https://pi.ai/api/chat/start', data="{}", headers={
            'accept': 'application/json',
            'x-api-version': '3'
        })
        if 'Just a moment' in response.text:
            raise RuntimeError('Error: Cloudflare detected')
        return Conversation(
            response.json()['conversations'][0]['sid'],
            response.cookies
        )

    @staticmethod
    def get_chat_history(scraper: CloudScraper, conversation: Conversation):
        """Fetch the message history of ``conversation`` as decoded JSON.

        Raises:
            RuntimeError: If a Cloudflare challenge page is returned.
        """
        params = {
            'conversation': conversation.sid,
        }
        response = scraper.get('https://pi.ai/api/chat/history', params=params, cookies=conversation.cookies)
        if 'Just a moment' in response.text:
            raise RuntimeError('Error: Cloudflare detected')
        return response.json()

    @staticmethod
    def ask(scraper: CloudScraper, messages: Messages, conversation: Conversation):
        """Post the prompt and yield the parsed JSON payload of each relevant event.

        Only ``data:`` lines whose JSON object begins with a ``"text"`` or a
        ``"title"`` key are decoded and yielded; all other lines are skipped.

        Raises:
            RuntimeError: If a line of the response contains the Cloudflare
                challenge marker.
        """
        json_data = {
            'text': format_prompt(messages),
            'conversation': conversation.sid,
            'mode': 'BASE',
        }
        response = scraper.post('https://pi.ai/api/chat', json=json_data, cookies=conversation.cookies, stream=True)
        # Server-sent event streams are always UTF-8. Without an explicit
        # encoding, requests may guess ISO-8859-1 for text/* (garbling non-ASCII
        # replies) or yield raw bytes, which would break the str checks below.
        response.encoding = 'utf-8'

        prefix = 'data: '
        for line in response.iter_lines(chunk_size=1024, decode_unicode=True):
            if 'Just a moment' in line:
                raise RuntimeError('Error: Cloudflare detected')
            if line.startswith(('data: {"text":', 'data: {"title":')):
                # Strip only the leading prefix: splitting on 'data: ' would cut
                # the payload short whenever the answer text itself contains it.
                yield json.loads(line[len(prefix):])
                
class Conversation:
    """Handle for one Pi conversation.

    Attributes are plain and mutable:
    ``sid`` is the conversation id sent as the ``conversation`` field to the
    chat API, and ``cookies`` are the cookies passed along with requests for
    this conversation.
    """

    def __init__(self, sid: str, cookies):
        self.sid, self.cookies = sid, cookies