diff --git a/reginaldCog/openai_client/models.py b/reginaldCog/openai_client/models.py index c17a0a9..200b98f 100644 --- a/reginaldCog/openai_client/models.py +++ b/reginaldCog/openai_client/models.py @@ -69,30 +69,38 @@ class Prompt: self.input.append(message) -def filter_fields(data: dict, filter_by_class: object) -> dict: - return {k: v for k, v in data.items() if k in tuple(f.name for f in fields(filter_by_class))} +@dataclass(frozen=True) +class Response: + output: list[Message] + + +def filter_fields(data: dict, filter_by_class: dataclass) -> dict: + return { + i_key: i_value + for i_key, i_value in data.items() + if i_key in tuple(f.name for f in fields(filter_by_class)) + } if __name__ == '__main__': from openai import OpenAI client = OpenAI() + test_system_message = Message(role='system', content=[InputText('Talk like an Italian mafia boss.')]) test_user_message = Message(role='user', content=[InputText('Hi! How are you?')]) test_prompt = Prompt(model='gpt-4.1-mini', input=[test_system_message, test_user_message]) - response = client.responses.create(**asdict(test_prompt)) - response_output = response.output[0].to_dict() - response_message = Message(**filter_fields(response_output, Message)) - print(response.to_dict().get('output')) - test_prompt.add_message(response_message) + print(asdict(test_prompt)) + response_raw = client.responses.create(**asdict(test_prompt)) + response = Response(**filter_fields(response_raw.to_dict(), Response)) + print(asdict(response)) + test_prompt.add_message(response.output[0]) + test_user_message_2 = Message(role='user', content=[ InputText('Can you tell me what is on this picture?'), InputImage(image_url='https://upload.wikimedia.org/wikipedia/commons/f/f4/Piet_Mondriaan_-_Sinaasappelen_%28authentiek%29_-_A97_-_Piet_Mondrian%2C_catalogue_raisonn%C3%A9.jpg') ]) test_prompt.add_message(test_user_message_2) - response = client.responses.create(**asdict(test_prompt)) - response_role = response.output[0].role - response_text = response.output[0].content[0].text - response_content = OutputText(response_text) - response_message = Message(role=response_role, content=[response_content]) - test_prompt.add_message(response_message) + response_raw = client.responses.create(**asdict(test_prompt)) + response = Response(**filter_fields(response_raw.to_dict(), Response)) + test_prompt.add_message(response.output[0]) print(asdict(test_prompt))