diff --git a/reginaldCog/llm_clients/discord_client_interfaces.py b/reginaldCog/llm_clients/discord_client_interfaces.py new file mode 100644 index 0000000..1426d4f --- /dev/null +++ b/reginaldCog/llm_clients/discord_client_interfaces.py @@ -0,0 +1,7 @@ +from _common import MessageFactory +from openai_data_models import Message + + +class DiscordMessageFactory(MessageFactory): + def new(self, raw_message) -> Message: + Message() \ No newline at end of file diff --git a/reginaldCog/llm_clients/openai_data_models.py b/reginaldCog/llm_clients/openai_data_models.py index 93f01a9..f8c4585 100644 --- a/reginaldCog/llm_clients/openai_data_models.py +++ b/reginaldCog/llm_clients/openai_data_models.py @@ -1,65 +1,95 @@ from dataclasses import dataclass, field, asdict -from datetime import datetime +from abc import ABC -# region dataclasses +# region Content classes @dataclass(frozen=True) -class Attachment: - content_type: str - filename: str - id: int - size: int - url: str - ephemeral: bool = False +class Content(ABC): + type: str = field(init=False, default='') @dataclass(frozen=True) -class Channel: - created_at: datetime - id: int +class InputText(Content): + type: str = field(init=False, default='input_text') + text: str + + +@dataclass(frozen=True) +class OutputText(Content): + type: str = field(init=False, default='output_text') + text: str + + +@dataclass(frozen=True) +class InputImage(Content): + type: str = field(init=False, default='input_image') + image_url: str = field(default=None) + file_id: str = field(default=None) + + def __post_init__(self): + if self.image_url is None and self.file_id is None: + raise ValueError('Either `image_url` or `file_id` must be provided.') + + +@dataclass(frozen=True) +class UrlCitation(Content): + type: str = field(init=False, default='url_citation') + # To be done + + +@dataclass(frozen=True) +class FunctionCall(Content): + type: str = field(init=False, default='function_call') + id: str + call_id: str name: str - - def __str__(self) -> str: - return self.name + arguments: dict @dataclass(frozen=True) -class Role: - id: int - name: str - - def __str__(self) -> str: - return self.name - - -@dataclass(frozen=True) -class Member: - bot: bool - created_at: datetime - display_name: str # For regular users this is just their global name or their username, but if they have a guild specific nickname then that is returned instead. - global_name: str # The user’s global nickname, taking precedence over the username in display. - id: int - joined_at: datetime - mention: str - name: str # The user’s username. - nick: str # The guild specific nickname of the user. Takes precedence over the global name. - roles: list[Role] = field(default_factory=list) - - def __str__(self) -> str: - return self.name +class FunctionCallOutput(Content): + type: str = field(init=False, default='function_call_output') + call_id: str + output: str +# endregion Content classes @dataclass(frozen=True) class Message: - author: Member - channel: Channel - created_at: datetime - id: int - attachments: list[Attachment] = field(default_factory=list) - channel_mentions: list[Channel] = field(default_factory=list) - content: str = '' - mentions: list[Member] = field(default_factory=list) + role: str + content: list[Content] + + +@dataclass(frozen=True) +class Prompt: + model: str + input: list[Message] + + +@dataclass(frozen=True) +class Response: + output: list[Message] if __name__ == '__main__': - pass + from openai import OpenAI + client = OpenAI() + + test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')]) + test_user_message = Message(role='user', content=[InputText('Hi! How are you?')]) + test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message]) + print(asdict(test_prompt)) + response_raw = client.responses.create(**asdict(test_prompt)) + response = Response(**filter_fields(response_raw.to_dict(), Response)) + print(asdict(response)) + test_prompt.add_message(response.output[0]) + + test_user_message_2 = Message(role='user', content=[ + InputText('Can you tell me what is on this picture?'), + InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg') + ]) + test_prompt.add_message(test_user_message_2) + response_raw = client.responses.create(**asdict(test_prompt)) + response = Response(**filter_fields(response_raw.to_dict(), Response)) + test_prompt.add_message(response.output[0]) + print(asdict(test_prompt)) diff --git a/reginaldCog/messenger_clients/_common.py b/reginaldCog/messenger_clients/_common.py index fb5996c..6dcaff2 100644 --- a/reginaldCog/messenger_clients/_common.py +++ b/reginaldCog/messenger_clients/_common.py @@ -3,5 +3,5 @@ from abc import ABC, abstractmethod class MessageFactory(ABC): @abstractmethod - def new(self, raw_message: dict) -> object: + def new(self, raw_message) -> object: pass diff --git a/reginaldCog/messenger_clients/discord_client_interfaces.py b/reginaldCog/messenger_clients/discord_client_interfaces.py deleted file mode 100644 index e69de29..0000000 diff --git a/reginaldCog/messenger_clients/discord_data_models.py b/reginaldCog/messenger_clients/discord_data_models.py index f8c4585..93f01a9 100644 --- a/reginaldCog/messenger_clients/discord_data_models.py +++ b/reginaldCog/messenger_clients/discord_data_models.py @@ -1,95 +1,65 @@ from dataclasses import dataclass, field, asdict -from abc import ABC +from datetime import datetime -# region Content classes +# region dataclasses @dataclass(frozen=True) -class Content(ABC): - type: str = field(init=False, default='') +class Attachment: + content_type: str + filename: str + id: int + size: int + url: str + ephemeral: bool = False @dataclass(frozen=True) -class InputText(Content): - type: str = field(init=False, default='input_text') - text: str - - -@dataclass(frozen=True) -class OutputText(Content): - type: str = field(init=False, default='output_text') - text: str - - -@dataclass(frozen=True) -class InputImage(Content): - type: str = field(init=False, default='input_image') - image_url: str = field(default=None) - file_id: str = field(default=None) - - def __post_init__(self): - if self.image_url is None and self.file_id is None: - raise ValueError('Either `image_url` or `file_id` must be provided.') - - -@dataclass(frozen=True) -class UrlCitation(Content): - type: str = field(init=False, default='url_citation') - # To be done - - -@dataclass(frozen=True) -class FunctionCall(Content): - type: str = field(init=False, default='function_call') - id: str - call_id: str +class Channel: + created_at: datetime + id: int name: str - arguments: dict + + def __str__(self) -> str: + return self.name @dataclass(frozen=True) -class FunctionCallOutput(Content): - type: str = field(init=False, default='function_call_output') - call_id: str - output: str -# endregion Content classes +class Role: + id: int + name: str + + def __str__(self) -> str: + return self.name + + +@dataclass(frozen=True) +class Member: + bot: bool + created_at: datetime + display_name: str # For regular users this is just their global name or their username, but if they have a guild specific nickname then that is returned instead. + global_name: str # The user’s global nickname, taking precedence over the username in display. + id: int + joined_at: datetime + mention: str + name: str # The user’s username. + nick: str # The guild specific nickname of the user. Takes precedence over the global name. + roles: list[Role] = field(default_factory=list) + + def __str__(self) -> str: + return self.name @dataclass(frozen=True) class Message: - role: str - content: list[Content] - - -@dataclass(frozen=True) -class Prompt: - model: str - input: list[Message] - - -@dataclass(frozen=True) -class Response: - output: list[Message] + author: Member + channel: Channel + created_at: datetime + id: int + attachments: list[Attachment] = field(default_factory=list) + channel_mentions: list[Channel] = field(default_factory=list) + content: str = '' + mentions: list[Member] = field(default_factory=list) if __name__ == '__main__': - from openai import OpenAI - client = OpenAI() - - test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')]) - test_user_message = Message(role='user', content=[InputText('Hi! How are you?')]) - test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message]) - print(asdict(test_prompt)) - response_raw = client.responses.create(**asdict(test_prompt)) - response = Response(**filter_fields(response_raw.to_dict(), Response)) - print(asdict(response)) - test_prompt.add_message(response.output[0]) - - test_user_message_2 = Message(role='user', content=[ - InputText('Can you tell me what is on this picture?'), - InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg') - ]) - test_prompt.add_message(test_user_message_2) - response_raw = client.responses.create(**asdict(test_prompt)) - response = Response(**filter_fields(response_raw.to_dict(), Response)) - test_prompt.add_message(response.output[0]) - print(asdict(test_prompt)) + pass