diff --git a/examples/text-from-text-streaming.rs b/examples/text-from-text-streaming.rs index adf0edb..578a729 100644 --- a/examples/text-from-text-streaming.rs +++ b/examples/text-from-text-streaming.rs @@ -15,7 +15,15 @@ async fn main() -> Result<(), Box> { ); let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; - let request = GenerateContentRequest::from_prompt(prompt, None); + let request = GenerateContentRequest::builder() + .add_content( + Content::builder() + .role(Role::User) + .add_part(Part::Text(prompt.to_string())) + .build(), + ) + .build(); + let queue = gemini.stream_generate_content(&request, "gemini-pro").await; while let Some(response) = queue.pop().await { diff --git a/examples/text-from-text.rs b/examples/text-from-text.rs index 06d5f9d..9b7cfe5 100644 --- a/examples/text-from-text.rs +++ b/examples/text-from-text.rs @@ -15,7 +15,14 @@ async fn main() -> Result<(), Box> { ); let prompt = "What is the airspeed of an unladen swallow?"; - let request = GenerateContentRequest::builder().with_prompt(prompt).build(); + let request = GenerateContentRequest::builder() + .add_content( + Content::builder() + .role(Role::User) + .add_part(Part::Text(prompt.to_string())) + .build(), + ) + .build(); let response = gemini.generate_content(&request, "gemini-pro").await?; println!("Response: {:?}", response.candidates[0].get_text().unwrap()); diff --git a/src/types/common.rs b/src/types/common.rs index 716e458..5afbf6f 100644 --- a/src/types/common.rs +++ b/src/types/common.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, str::FromStr}; +use std::{collections::HashMap, str::FromStr, vec}; use serde::{Deserialize, Serialize}; @@ -20,6 +20,39 @@ impl Content { .collect::() }) } + + pub fn builder() -> ContentBuilder { + ContentBuilder::new() + } +} + +pub struct ContentBuilder { + content: Content, +} + +impl ContentBuilder { + pub fn new() -> Self { + Self { + content: Default::default(), + } + } + + pub fn add_part(mut self, part: Part) -> Self { + match &mut self.content.parts { + Some(parts) => parts.push(part), + None => self.content.parts = Some(vec![part]), + } + self + } + + pub fn role(mut self, role: Role) -> Self { + self.content.role = Some(role); + self + } + + pub fn build(self) -> Content { + self.content + } } #[derive(Clone, Copy, Debug, Serialize, Deserialize)] diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index e905104..c5204b3 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use serde_json::Value; -use super::{Content, Part, Role, VertexApiError}; +use super::{Content, VertexApiError}; use crate::error::Result; #[derive(Clone, Default, Serialize, Deserialize)] @@ -21,19 +21,6 @@ pub struct GenerateContentRequest { } impl GenerateContentRequest { - pub fn from_prompt(prompt: &str, generation_config: Option) -> Self { - GenerateContentRequest { - contents: vec![Content { - role: Some(Role::User), - parts: Some(vec![Part::Text(prompt.to_string())]), - }], - generation_config, - tools: None, - system_instruction: None, - safety_settings: None, - } - } - pub fn builder() -> GenerateContentRequestBuilder { GenerateContentRequestBuilder::new() } @@ -50,30 +37,27 @@ impl GenerateContentRequestBuilder { } } - pub fn with_prompt(mut self, prompt: &str) -> Self { - self.request.contents = vec![Content { - role: Some(Role::User), - parts: Some(vec![Part::Text(prompt.to_string())]), - }]; + pub fn add_content(mut self, content: Content) -> Self { + self.request.contents.push(content); self } - pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self { + pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self { self.request.generation_config = Some(generation_config); self } - pub fn with_tools(mut self, tools: Vec) -> Self { + pub fn tools(mut self, tools: Vec) -> Self { self.request.tools = Some(tools); self } - pub fn with_safety_settings(mut self, safety_settings: Vec) -> Self { + pub fn safety_settings(mut self, safety_settings: Vec) -> Self { self.request.safety_settings = Some(safety_settings); self } - pub fn with_system_instruction(mut self, system_instruction: Content) -> Self { + pub fn system_instruction(mut self, system_instruction: Content) -> Self { self.request.system_instruction = Some(system_instruction); self } @@ -119,6 +103,68 @@ pub struct GenerationConfig { pub response_schema: Option, } +impl GenerationConfig { + pub fn builder() -> GenerationConfigBuilder { + GenerationConfigBuilder::new() + } +} + +pub struct GenerationConfigBuilder { + generation_config: GenerationConfig, +} + +impl GenerationConfigBuilder { + fn new() -> Self { + Self { + generation_config: Default::default(), + } + } + + pub fn max_output_tokens>(mut self, max_output_tokens: T) -> Self { + self.generation_config.max_output_tokens = Some(max_output_tokens.into()); + self + } + + pub fn temperature>(mut self, temperature: T) -> Self { + self.generation_config.temperature = Some(temperature.into()); + self + } + + pub fn top_p>(mut self, top_p: T) -> Self { + self.generation_config.top_p = Some(top_p.into()); + self + } + + pub fn top_k>(mut self, top_k: T) -> Self { + self.generation_config.top_k = Some(top_k.into()); + self + } + + pub fn stop_sequences>>(mut self, stop_sequences: T) -> Self { + self.generation_config.stop_sequences = Some(stop_sequences.into()); + self + } + + pub fn candidate_count>(mut self, candidate_count: T) -> Self { + self.generation_config.candidate_count = Some(candidate_count.into()); + self + } + + pub fn response_mime_type>(mut self, response_mime_type: T) -> Self { + self.generation_config.response_mime_type = Some(response_mime_type.into()); + self + } + + pub fn response_schema>(mut self, response_schema: T) -> Self { + self.generation_config.response_schema = Some(response_schema.into()); + self + } + + pub fn build(self) -> GenerationConfig { + self.generation_config + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct SafetySetting {