99 lines
3.0 KiB
Python

from abc import ABC
from dataclasses import dataclass, field, asdict, fields
from typing import Optional
# region Content classes
@dataclass(frozen=True)
class Content(ABC):
    """Common base for every message content part.

    Each subclass overrides ``type`` with its own tag. The field is excluded
    from ``__init__`` (callers never pass it) but, like any dataclass field,
    it still appears in ``asdict`` output.
    """

    type: str = field(default='', init=False)
@dataclass(frozen=True)
class InputText(Content):
    """Plain-text input part, serialised under the ``input_text`` tag."""

    type: str = field(default='input_text', init=False)
    text: str  # the text supplied by the caller
@dataclass(frozen=True)
class OutputText(Content):
    """Text part tagged ``output_text``; this file uses it for assistant replies."""

    type: str = field(default='output_text', init=False)
    text: str  # the reply text
@dataclass(frozen=True)
class InputImage(Content):
    """Image input part, serialised under the ``input_image`` tag.

    The image is referenced either by URL (``image_url``) or by an uploaded
    file id (``file_id``). Both default to ``None``; at least one must be set.

    Raises:
        ValueError: if neither ``image_url`` nor ``file_id`` is provided.
    """

    type: str = field(init=False, default='input_image')
    # Optional[...] because both default to None; a plain ``str`` annotation
    # contradicted the default.
    image_url: Optional[str] = field(default=None)
    file_id: Optional[str] = field(default=None)

    def __post_init__(self):
        # Only reads fields, so it is safe on a frozen dataclass.
        if self.image_url is None and self.file_id is None:
            raise ValueError('Either `image_url` or `file_id` must be provided.')
@dataclass(frozen=True)
class UrlCitation(Content):
    """Placeholder for a ``url_citation`` part; so far it carries only its tag."""

    # TODO: add the citation's payload fields; not implemented yet.
    type: str = field(default='url_citation', init=False)
@dataclass(frozen=True)
class FunctionCall(Content):
    """A ``function_call`` item: a tool invocation requested by the model.

    Field names mirror the API's keys, so ``asdict`` output can be sent as-is.
    ``arguments`` is annotated ``str`` because the OpenAI Responses API
    transmits function-call arguments as a JSON-encoded string
    (e.g. ``'{"city": "Rome"}'``). Decode it with ``json.loads`` before use,
    and encode with ``json.dumps`` when building one by hand.
    """

    type: str = field(init=False, default='function_call')
    id: str  # item id; named to match the API key
    call_id: str
    name: str  # name of the function the model wants to call
    arguments: str  # JSON-encoded argument object
@dataclass(frozen=True)
class FunctionCallOutput(Content):
    """Result of a tool call, serialised under the ``function_call_output`` tag.

    ``call_id`` is expected to match the ``call_id`` of the ``FunctionCall``
    this output answers.
    """

    type: str = field(default='function_call_output', init=False)
    call_id: str
    output: str  # the tool's result, as text
# endregion Content classes
@dataclass(frozen=True)
class Message:
    """One conversational turn: who speaks (``role``) and what is said (``content``)."""

    role: str  # e.g. 'system' or 'user', as used in this file's demo
    content: list[Content]  # ordered content parts of the turn
@dataclass
class Prompt:
    """Request payload: the target ``model`` plus the running ``input`` history.

    Left mutable (not frozen) so the history can grow through ``add_message``.
    ``asdict(prompt)`` is passed as keyword arguments to
    ``client.responses.create`` in this file's demo.
    """

    model: str
    input: list[Message]

    def add_message(self, message: Message) -> None:
        """Append *message* to the end of the input history, in place."""
        history = self.input
        history.append(message)
def filter_fields(data: dict, filter_by_class: object) -> dict:
    """Keep only the entries of *data* whose keys are fields of a dataclass.

    Useful for building a dataclass from a dict that carries extra keys,
    e.g. ``Message(**filter_fields(item_dict, Message))``.

    Args:
        data: Mapping to filter.
        filter_by_class: A dataclass type or instance. The names returned by
            ``dataclasses.fields`` decide which keys are kept.

    Returns:
        A new dict with the matching key/value pairs, in *data*'s order.

    Raises:
        TypeError: if *filter_by_class* is not a dataclass (from ``fields``).
    """
    # Compute the field names once, outside the comprehension, so each key
    # check is a single set lookup instead of rebuilding a tuple per item.
    allowed = {f.name for f in fields(filter_by_class)}
    return {key: value for key, value in data.items() if key in allowed}
if __name__ == '__main__':
    from openai import OpenAI

    def _first_output_message(response) -> Message:
        """Build a Message from the first ``message`` item of a Responses result.

        The item's ``content`` parts are kept as the plain dicts produced by
        ``to_dict()`` (not Content objects), so they are sent back to the API
        exactly as it returned them.

        Raises:
            RuntimeError: if the output list is empty or contains no
                ``message`` item.
        """
        for item in response.to_dict().get('output', []):
            # Search by type instead of assuming output[0] is the message.
            if item.get('type') == 'message':
                return Message(**filter_fields(item, Message))
        raise RuntimeError('Response contains no message output item.')

    client = OpenAI()

    test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')])
    test_user_message = Message(role='user', content=[InputText('Hi! How are you?')])
    test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message])

    # First turn: text only.
    response = client.responses.create(**asdict(test_prompt))
    print(response.to_dict().get('output'))
    test_prompt.add_message(_first_output_message(response))

    # Second turn: text plus an image, with the full history resent.
    test_user_message_2 = Message(role='user', content=[
        InputText('Can you tell me what is on this picture?'),
        InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg'),
    ])
    test_prompt.add_message(test_user_message_2)
    response = client.responses.create(**asdict(test_prompt))
    test_prompt.add_message(_first_output_message(response))

    print(asdict(test_prompt))