From 326b3919d183d6630f5a85466fc9d7eb48873e58 Mon Sep 17 00:00:00 2001 From: Andre Cipriani Bandarra Date: Tue, 26 Nov 2024 20:27:36 +0000 Subject: [PATCH] Add builders for GenerateContent and CountToken --- examples/count-tokens.rs | 7 +--- examples/google-search-retrieval.rs | 2 +- examples/json-schema.rs | 2 +- examples/safety-setting.rs | 2 +- examples/system_instruction.rs | 2 +- examples/text-from-text.rs | 5 +-- src/client.rs | 9 ++--- src/dialogue.rs | 31 +---------------- src/types/common.rs | 34 +++++++++++++++++-- src/types/count_tokens.rs | 33 ++++++++++++++++++ src/types/generate_content.rs | 52 +++++++++++++++++++++++++++-- 11 files changed, 128 insertions(+), 51 deletions(-) diff --git a/examples/count-tokens.rs b/examples/count-tokens.rs index 7541a46..8a4b102 100644 --- a/examples/count-tokens.rs +++ b/examples/count-tokens.rs @@ -15,12 +15,7 @@ async fn main() -> Result<(), Box> { ); let prompt = "What is the airspeed of an unladen swallow?"; - let request = CountTokensRequest { - contents: Content { - role: Some("user".to_string()), - parts: Some(vec![Part::Text(prompt.to_string())]), - }, - }; + let request = CountTokensRequestBuilder::from_prompt(prompt).build(); let result = gemini.count_tokens(&request, "gemini-pro").await?; println!("Response: {:?}", result); diff --git a/examples/google-search-retrieval.rs b/examples/google-search-retrieval.rs index 275c15c..d209096 100644 --- a/examples/google-search-retrieval.rs +++ b/examples/google-search-retrieval.rs @@ -19,7 +19,7 @@ async fn main() -> Result<(), Box> { let request = GenerateContentRequest { contents: vec![Content { - role: Some("user".to_string()), + role: Some(Role::User), parts: Some(vec![Part::Text(prompt.to_string())]), }], tools: Some(vec![Tools { diff --git a/examples/json-schema.rs b/examples/json-schema.rs index 748633a..c56cf23 100644 --- a/examples/json-schema.rs +++ b/examples/json-schema.rs @@ -18,7 +18,7 @@ async fn main() -> Result<(), Box> { let prompt = "Generate 10 ideas of blog posts with a title and decription for each idea."; let request = GenerateContentRequest { contents: vec![Content { - role: Some("user".to_string()), + role: Some(Role::User), parts: Some(vec![Part::Text(prompt.to_string())]), }], generation_config: Some(GenerationConfig { diff --git a/examples/safety-setting.rs b/examples/safety-setting.rs index f321265..a4840d0 100644 --- a/examples/safety-setting.rs +++ b/examples/safety-setting.rs @@ -20,7 +20,7 @@ async fn main() -> Result<(), Box> { let request = GenerateContentRequest { contents: vec![Content { - role: Some("user".to_string()), + role: Some(Role::User), parts: Some(vec![Part::Text(prompt.to_string())]), }], safety_settings: Some(vec![SafetySetting { diff --git a/examples/system_instruction.rs b/examples/system_instruction.rs index 4def033..8330f33 100644 --- a/examples/system_instruction.rs +++ b/examples/system_instruction.rs @@ -20,7 +20,7 @@ async fn main() -> Result<(), Box> { let request = GenerateContentRequest { contents: vec![Content { - role: Some("user".to_string()), + role: Some(Role::User), parts: Some(vec![Part::Text(prompt.to_string())]), }], system_instruction: Some(Content { diff --git a/examples/text-from-text.rs b/examples/text-from-text.rs index f69790f..06d5f9d 100644 --- a/examples/text-from-text.rs +++ b/examples/text-from-text.rs @@ -15,8 +15,9 @@ async fn main() -> Result<(), Box> { ); let prompt = "What is the airspeed of an unladen swallow?"; - let result = gemini.prompt_text(prompt, None).await?; - println!("Response: {}", result); + let request = GenerateContentRequest::builder().with_prompt(prompt).build(); + let response = gemini.generate_content(&request, "gemini-pro").await?; + println!("Response: {:?}", response.candidates[0].get_text().unwrap()); Ok(()) } diff --git a/src/client.rs b/src/client.rs index e36ba0b..c54ded3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,14 +5,14 @@ use futures_util::stream::StreamExt; use reqwest_eventsource::{Event, EventSource}; use tracing::error; -use crate::dialogue::{Message, Role}; +use crate::dialogue::Message; use crate::error::{Error, Result}; use crate::prelude::{ Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse, }; -use crate::types::{PredictImageRequest, PredictImageResponse}; +use crate::types::{PredictImageRequest, PredictImageResponse, Role}; use crate::{prelude::Part, token_provider::TokenProvider}; pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; @@ -154,7 +154,7 @@ impl GeminiClient { contents: messages .iter() .map(|m| Content { - role: Some(m.role.to_string()), + role: Some(m.role), parts: Some(vec![Part::Text(m.text.clone())]), }) .collect(), @@ -177,6 +177,7 @@ impl GeminiClient { /// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text /// from the response. + #[deprecated(note = "Use `generate_content` instead")] pub async fn prompt_text( &self, prompt: &str, @@ -184,7 +185,7 @@ impl GeminiClient { ) -> Result { let request = GenerateContentRequest { contents: vec![Content { - role: Some("user".to_string()), + role: Some(Role::User), parts: Some(vec![Part::Text(prompt.to_string())]), }], generation_config: generation_config.cloned(), diff --git a/src/dialogue.rs b/src/dialogue.rs index 56d81b9..e819494 100644 --- a/src/dialogue.rs +++ b/src/dialogue.rs @@ -1,35 +1,6 @@ -use std::str::FromStr; - use serde::{Deserialize, Serialize}; -use crate::{client::GeminiClient, error::Result, prelude::TokenProvider}; - -#[derive(Clone, Debug, Serialize, Deserialize)] -pub enum Role { - User, - Model, -} - -impl ToString for Role { - fn to_string(&self) -> String { - match self { - Role::User => "user".to_string(), - Role::Model => "model".to_string(), - } - } -} - -impl FromStr for Role { - type Err = (); - - fn from_str(s: &str) -> std::result::Result { - match s { - "user" => Ok(Role::User), - "model" => Ok(Role::Model), - _ => Err(()), - } - } -} +use crate::{client::GeminiClient, error::Result, prelude::TokenProvider, types::Role}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct Message { diff --git a/src/types/common.rs b/src/types/common.rs index 7a4d0dc..716e458 100644 --- a/src/types/common.rs +++ b/src/types/common.rs @@ -1,10 +1,10 @@ -use std::collections::HashMap; +use std::{collections::HashMap, str::FromStr}; use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, Serialize, Deserialize)] +#[derive(Clone, Default, Debug, Serialize, Deserialize)] pub struct Content { - pub role: Option, + pub role: Option, pub parts: Option>, } @@ -22,6 +22,34 @@ impl Content { } } +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Model, +} + +impl ToString for Role { + fn to_string(&self) -> String { + match self { + Role::User => "user".to_string(), + Role::Model => "model".to_string(), + } + } +} + +impl FromStr for Role { + type Err = (); + + fn from_str(s: &str) -> std::result::Result { + match s { + "user" => Ok(Role::User), + "model" => Ok(Role::Model), + _ => Err(()), + } + } +} + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum Part { diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs index cce0612..a152055 100644 --- a/src/types/count_tokens.rs +++ b/src/types/count_tokens.rs @@ -7,6 +7,39 @@ pub struct CountTokensRequest { pub contents: Content, } +impl CountTokensRequest { + pub fn builder() -> CountTokensRequestBuilder { + CountTokensRequestBuilder::new() + } +} + +pub struct CountTokensRequestBuilder { + contents: Content, +} + +impl CountTokensRequestBuilder { + pub fn new() -> Self { + CountTokensRequestBuilder { + contents: Content::default(), + } + } + + pub fn from_prompt(prompt: &str) -> Self { + CountTokensRequestBuilder { + contents: Content { + parts: Some(vec![super::Part::Text(prompt.to_string())]), + ..Default::default() + }, + } + } + + pub fn build(self) -> CountTokensRequest { + CountTokensRequest { + contents: self.contents, + } + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(untagged)] pub enum CountTokensResponse { diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs index 99c2df8..e905104 100644 --- a/src/types/generate_content.rs +++ b/src/types/generate_content.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use serde_json::Value; -use super::{Content, Part, VertexApiError}; +use super::{Content, Part, Role, VertexApiError}; use crate::error::Result; #[derive(Clone, Default, Serialize, Deserialize)] @@ -24,7 +24,7 @@ impl GenerateContentRequest { pub fn from_prompt(prompt: &str, generation_config: Option) -> Self { GenerateContentRequest { contents: vec![Content { - role: Some("user".to_string()), + role: Some(Role::User), parts: Some(vec![Part::Text(prompt.to_string())]), }], generation_config, @@ -33,6 +33,54 @@ impl GenerateContentRequest { safety_settings: None, } } + + pub fn builder() -> GenerateContentRequestBuilder { + GenerateContentRequestBuilder::new() + } +} + +pub struct GenerateContentRequestBuilder { + request: GenerateContentRequest, +} + +impl GenerateContentRequestBuilder { + fn new() -> Self { + GenerateContentRequestBuilder { + request: GenerateContentRequest::default(), + } + } + + pub fn with_prompt(mut self, prompt: &str) -> Self { + self.request.contents = vec![Content { + role: Some(Role::User), + parts: Some(vec![Part::Text(prompt.to_string())]), + }]; + self + } + + pub fn with_generation_config(mut self, generation_config: GenerationConfig) -> Self { + self.request.generation_config = Some(generation_config); + self + } + + pub fn with_tools(mut self, tools: Vec) -> Self { + self.request.tools = Some(tools); + self + } + + pub fn with_safety_settings(mut self, safety_settings: Vec) -> Self { + self.request.safety_settings = Some(safety_settings); + self + } + + pub fn with_system_instruction(mut self, system_instruction: Content) -> Self { + self.request.system_instruction = Some(system_instruction); + self + } + + pub fn build(self) -> GenerateContentRequest { + self.request + } } #[derive(Clone, Default, Serialize, Deserialize)]