File structure refactor
This commit is contained in:
parent
180c167a43
commit
ff70196756
7
reginaldCog/llm_clients/discord_client_interfaces.py
Normal file
7
reginaldCog/llm_clients/discord_client_interfaces.py
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
from _common import MessageFactory
|
||||||
|
from openai_data_models import Message
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordMessageFactory(MessageFactory):
|
||||||
|
def new(self, raw_message) -> Message:
|
||||||
|
Message()
|
||||||
@ -1,65 +1,95 @@
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from datetime import datetime
|
from abc import ABC
|
||||||
|
|
||||||
|
|
||||||
# region dataclasses
|
# region Content classes
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Attachment:
|
class Content(ABC):
|
||||||
content_type: str
|
type: str = field(init=False, default='')
|
||||||
filename: str
|
|
||||||
id: int
|
|
||||||
size: int
|
|
||||||
url: str
|
|
||||||
ephemeral: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Channel:
|
class InputText(Content):
|
||||||
created_at: datetime
|
type: str = field(init=False, default='input_text')
|
||||||
id: int
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class OutputText(Content):
|
||||||
|
type: str = field(init=False, default='output_text')
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class InputImage(Content):
|
||||||
|
type: str = field(init=False, default='input_image')
|
||||||
|
image_url: str = field(default=None)
|
||||||
|
file_id: str = field(default=None)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.image_url is None and self.file_id is None:
|
||||||
|
raise ValueError('Either `image_url` or `file_id` must be provided.')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class UrlCitation(Content):
|
||||||
|
type: str = field(init=False, default='url_citation')
|
||||||
|
# To be done
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class FunctionCall(Content):
|
||||||
|
type: str = field(init=False, default='function_call')
|
||||||
|
id: str
|
||||||
|
call_id: str
|
||||||
name: str
|
name: str
|
||||||
|
arguments: dict
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.name
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Role:
|
class FunctionCallOutput(Content):
|
||||||
id: int
|
type: str = field(init=False, default='function_call_output')
|
||||||
name: str
|
call_id: str
|
||||||
|
output: str
|
||||||
def __str__(self) -> str:
|
# endregion Content classes
|
||||||
return self.name
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Member:
|
|
||||||
bot: bool
|
|
||||||
created_at: datetime
|
|
||||||
display_name: str # For regular users this is just their global name or their username, but if they have a guild specific nickname then that is returned instead.
|
|
||||||
global_name: str # The user’s global nickname, taking precedence over the username in display.
|
|
||||||
id: int
|
|
||||||
joined_at: datetime
|
|
||||||
mention: str
|
|
||||||
name: str # The user’s username.
|
|
||||||
nick: str # The guild specific nickname of the user. Takes precedence over the global name.
|
|
||||||
roles: list[Role] = field(default_factory=list)
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
return self.name
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Message:
|
class Message:
|
||||||
author: Member
|
role: str
|
||||||
channel: Channel
|
content: list[Content]
|
||||||
created_at: datetime
|
|
||||||
id: int
|
|
||||||
attachments: list[Attachment] = field(default_factory=list)
|
@dataclass(frozen=True)
|
||||||
channel_mentions: list[Channel] = field(default_factory=list)
|
class Prompt:
|
||||||
content: str = ''
|
model: str
|
||||||
mentions: list[Member] = field(default_factory=list)
|
input: list[Message]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Response:
|
||||||
|
output: list[Message]
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pass
|
from openai import OpenAI
|
||||||
|
client = OpenAI()
|
||||||
|
|
||||||
|
test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')])
|
||||||
|
test_user_message = Message(role='user', content=[InputText('Hi! How are you?')])
|
||||||
|
test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message])
|
||||||
|
print(asdict(test_prompt))
|
||||||
|
response_raw = client.responses.create(**asdict(test_prompt))
|
||||||
|
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
||||||
|
print(asdict(response))
|
||||||
|
test_prompt.add_message(response.output[0])
|
||||||
|
|
||||||
|
test_user_message_2 = Message(role='user', content=[
|
||||||
|
InputText('Can you tell me what is on this picture?'),
|
||||||
|
InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg')
|
||||||
|
])
|
||||||
|
test_prompt.add_message(test_user_message_2)
|
||||||
|
response_raw = client.responses.create(**asdict(test_prompt))
|
||||||
|
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
||||||
|
test_prompt.add_message(response.output[0])
|
||||||
|
print(asdict(test_prompt))
|
||||||
|
|||||||
@ -3,5 +3,5 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
class MessageFactory(ABC):
|
class MessageFactory(ABC):
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def new(self, raw_message: dict) -> object:
|
def new(self, raw_message) -> object:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@ -1,95 +1,65 @@
|
|||||||
from dataclasses import dataclass, field, asdict
|
from dataclasses import dataclass, field, asdict
|
||||||
from abc import ABC
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
# region Content classes
|
# region dataclasses
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Content(ABC):
|
class Attachment:
|
||||||
type: str = field(init=False, default='')
|
content_type: str
|
||||||
|
filename: str
|
||||||
|
id: int
|
||||||
|
size: int
|
||||||
|
url: str
|
||||||
|
ephemeral: bool = False
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class InputText(Content):
|
class Channel:
|
||||||
type: str = field(init=False, default='input_text')
|
created_at: datetime
|
||||||
text: str
|
id: int
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class OutputText(Content):
|
|
||||||
type: str = field(init=False, default='output_text')
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class InputImage(Content):
|
|
||||||
type: str = field(init=False, default='input_image')
|
|
||||||
image_url: str = field(default=None)
|
|
||||||
file_id: str = field(default=None)
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.image_url is None and self.file_id is None:
|
|
||||||
raise ValueError('Either `image_url` or `file_id` must be provided.')
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class UrlCitation(Content):
|
|
||||||
type: str = field(init=False, default='url_citation')
|
|
||||||
# To be done
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class FunctionCall(Content):
|
|
||||||
type: str = field(init=False, default='function_call')
|
|
||||||
id: str
|
|
||||||
call_id: str
|
|
||||||
name: str
|
name: str
|
||||||
arguments: dict
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class FunctionCallOutput(Content):
|
class Role:
|
||||||
type: str = field(init=False, default='function_call_output')
|
id: int
|
||||||
call_id: str
|
name: str
|
||||||
output: str
|
|
||||||
# endregion Content classes
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Member:
|
||||||
|
bot: bool
|
||||||
|
created_at: datetime
|
||||||
|
display_name: str # For regular users this is just their global name or their username, but if they have a guild specific nickname then that is returned instead.
|
||||||
|
global_name: str # The user’s global nickname, taking precedence over the username in display.
|
||||||
|
id: int
|
||||||
|
joined_at: datetime
|
||||||
|
mention: str
|
||||||
|
name: str # The user’s username.
|
||||||
|
nick: str # The guild specific nickname of the user. Takes precedence over the global name.
|
||||||
|
roles: list[Role] = field(default_factory=list)
|
||||||
|
|
||||||
|
def __str__(self) -> str:
|
||||||
|
return self.name
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class Message:
|
class Message:
|
||||||
role: str
|
author: Member
|
||||||
content: list[Content]
|
channel: Channel
|
||||||
|
created_at: datetime
|
||||||
|
id: int
|
||||||
@dataclass(frozen=True)
|
attachments: list[Attachment] = field(default_factory=list)
|
||||||
class Prompt:
|
channel_mentions: list[Channel] = field(default_factory=list)
|
||||||
model: str
|
content: str = ''
|
||||||
input: list[Message]
|
mentions: list[Member] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class Response:
|
|
||||||
output: list[Message]
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from openai import OpenAI
|
pass
|
||||||
client = OpenAI()
|
|
||||||
|
|
||||||
test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')])
|
|
||||||
test_user_message = Message(role='user', content=[InputText('Hi! How are you?')])
|
|
||||||
test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message])
|
|
||||||
print(asdict(test_prompt))
|
|
||||||
response_raw = client.responses.create(**asdict(test_prompt))
|
|
||||||
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
|
||||||
print(asdict(response))
|
|
||||||
test_prompt.add_message(response.output[0])
|
|
||||||
|
|
||||||
test_user_message_2 = Message(role='user', content=[
|
|
||||||
InputText('Can you tell me what is on this picture?'),
|
|
||||||
InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg')
|
|
||||||
])
|
|
||||||
test_prompt.add_message(test_user_message_2)
|
|
||||||
response_raw = client.responses.create(**asdict(test_prompt))
|
|
||||||
response = Response(**filter_fields(response_raw.to_dict(), Response))
|
|
||||||
test_prompt.add_message(response.output[0])
|
|
||||||
print(asdict(test_prompt))
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user