commit 6786a5e9f0aa6fc62efcfa6e0c13474c4cc93460 Author: Andre Bandarra Date: Wed Nov 26 17:51:15 2025 +0000 Initial Commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4c49bd7 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.env diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..17b1176 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "google-genai" +version = "0.1.0" +edition = "2024" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +deadqueue = "0.2" +reqwest = { version = "0.12", features = ["json", "gzip"] } +reqwest-eventsource = "0.6" +serde = { version = "1", features = ["derive"] } +serde_json = { version = "1"} +serde_with = { version = "3.9", features = ["base64"]} +tracing = "0.1" +tokio = { version = "1" } +tokio-stream = "0.1.17" + +[dev-dependencies] +console = "0.16.0" +dialoguer = "0.12.0" +dotenvy = "0.15.7" +image = "0.25.2" +indicatif = "0.18.0" +tokio = { version = "1.47.0", features = ["full"] } +tracing-subscriber = "0.3.20" + diff --git a/examples/function-call.rs b/examples/function-call.rs new file mode 100644 index 0000000..288a217 --- /dev/null +++ b/examples/function-call.rs @@ -0,0 +1,113 @@ +use std::{env, error::Error}; + +use google_genai::prelude::{ + Content, FunctionDeclaration, GeminiClient, GenerateContentRequest, Part, PartData, Role, Tools, +}; +use serde_json::json; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().init(); + let _ = dotenvy::dotenv(); + + let api_key = env::var("GEMINI_API_KEY")?; + let gemini_client = GeminiClient::new(api_key); + + let mut contents = vec![ + Content::builder() + .role(Role::User) + .add_text_part("What is the sbrubbles value of 213 and 231?") + .build(), + ]; + + let sum_parameters = json!({ + "type": "object", + "properties": { + "left": { + "type": "integer", + "description": "The first value for the sbrubbles calculation" + }, + "right": { + "type": "integer", + "description": "The second value for the sbrubbles calculation" + } + }, + "required": ["left", "right"] + }); + + let sum_result = json!({ + "type": "object", + "properties": { + "result": { + "type": "integer", + "description": "The sbrubbles value calculation result", + } + } + }); + + let sum_function = FunctionDeclaration { + name: String::from("sbrubbles"), + description: String::from("Calculates the sbrubbles value"), + parameters: None, + parameters_json_schema: Some(sum_parameters), + response: None, + response_json_schema: Some(sum_result), + }; + + println!("{}", serde_json::to_string_pretty(&sum_function).unwrap()); + + let tools = Tools { + function_declarations: Some(vec![sum_function]), + ..Default::default() + }; + + let request = GenerateContentRequest::builder() + .contents(contents.clone()) + .tools(vec![tools.clone()]) + .build(); + + let mut response = gemini_client + .generate_content(&request, "gemini-3-pro-preview") + .await?; + + while let Some(candidate) = response.candidates.last() + && let Some(content) = &candidate.content + && let Some(parts) = &content.parts + && let Some(part) = parts.last() + && let PartData::FunctionCall { id, name, args, .. } = &part.data + { + contents.push(content.clone()); + match args { + Some(args) => println!("Function call: {name}, {args}"), + None => println!("Function call: {name}"), + } + + contents.push(Content { + role: Some(Role::User), + parts: Some(vec![Part { + data: PartData::FunctionResponse { + id: id.clone(), + name: name.clone(), + response: json!({"result": 1234}), + will_continue: None, + }, + media_resolution: None, + part_metadata: None, + thought: None, + thought_signature: part.thought_signature.clone(), + }]), + }); + let request = GenerateContentRequest::builder() + .contents(contents.clone()) + .tools(vec![tools.clone()]) + .build(); + + println!("{contents:?}"); + response = gemini_client + .generate_content(&request, "gemini-3-pro-preview") + .await?; + } + + println!("Response: {:#?}", response.candidates); + Ok(()) +} diff --git a/examples/generate-content.rs b/examples/generate-content.rs new file mode 100644 index 0000000..2b7f775 --- /dev/null +++ b/examples/generate-content.rs @@ -0,0 +1,26 @@ +use std::{env, error::Error}; + +use google_genai::prelude::{Content, GeminiClient, GenerateContentRequest, Role}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().init(); + let _ = dotenvy::dotenv(); + + let api_key = env::var("GEMINI_API_KEY")?; + let gemini_client = GeminiClient::new(api_key); + let prompt = vec![ + Content::builder() + .role(Role::User) + .add_text_part("What is the airspeed of an unladen swallow?") + .build(), + ]; + + let request = GenerateContentRequest::builder().contents(prompt).build(); + let response = gemini_client + .generate_content(&request, "gemini-3-pro-preview") + .await?; + + println!("Response: {:?}", response.candidates[0].get_text().unwrap()); + Ok(()) +} diff --git a/examples/generate-image.rs b/examples/generate-image.rs new file mode 100644 index 0000000..854c6d0 --- /dev/null +++ b/examples/generate-image.rs @@ -0,0 +1,50 @@ +use std::{env, error::Error, io::Cursor}; + +use google_genai::prelude::{ + GeminiClient, PersonGeneration, PredictImageRequest, PredictImageRequestParameters, + PredictImageRequestParametersOutputOptions, PredictImageRequestPrompt, + PredictImageSafetySetting, +}; +use image::{ImageFormat, ImageReader}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + tracing_subscriber::fmt().init(); + let _ = dotenvy::dotenv(); + + let api_key = env::var("GEMINI_API_KEY")?; + let gemini_client = GeminiClient::new(api_key); + + let prompt = " + Create an image of a tuxedo cat riding a rocket to the moon."; + let request = PredictImageRequest { + instances: vec![PredictImageRequestPrompt { + prompt: prompt.to_string(), + }], + parameters: PredictImageRequestParameters { + sample_count: 1, + aspect_ratio: Some("1:1".to_string()), + output_options: Some(PredictImageRequestParametersOutputOptions { + mime_type: Some("image/jpeg".to_string()), + compression_quality: Some(75), + }), + person_generation: Some(PersonGeneration::AllowAdult), + safety_setting: Some(PredictImageSafetySetting::BlockLowAndAbove), + ..Default::default() + }, + }; + + println!("Request: {:#?}", serde_json::to_string(&request).unwrap()); + + let mut result = gemini_client + .predict_image(&request, "imagen-4.0-generate-001") + .await?; + + let result = result.predictions.pop().unwrap(); + + let format = ImageFormat::from_mime_type(result.mime_type).unwrap(); + let img = + ImageReader::with_format(Cursor::new(result.bytes_base64_encoded), format).decode()?; + img.save("output.jpg")?; + Ok(()) +} diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..da86308 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,309 @@ +use crate::error::Result as GeminiResult; +use std::sync::Arc; +use std::vec; +use tokio_stream::{Stream, StreamExt}; + +use deadqueue::unlimited::Queue; +use reqwest_eventsource::{Event, EventSource}; +use tracing::error; + +use crate::dialogue::Message; +use crate::error::{Error, Result}; +use crate::prelude::Part; +use crate::prelude::{ + Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, + GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest, + TextEmbeddingResponse, +}; +use crate::types::{PredictImageRequest, PredictImageResponse, Role}; + +pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"]; + +#[derive(Clone, Debug)] +pub struct GeminiClient { + client: reqwest::Client, + api_key: String, +} + +unsafe impl Send for GeminiClient {} +unsafe impl Sync for GeminiClient {} + +impl GeminiClient { + pub fn new(api_key: String) -> Self { + GeminiClient { + client: reqwest::Client::new(), + api_key, + } + } + + pub async fn generate_content_stream( + &self, + request: &GenerateContentRequest, + model: &str, + ) -> Result>> { + let endpoint_url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse" + ); + let client = self.client.clone(); + let request = request.clone(); + let req = client + .post(&endpoint_url) + .header("x-goog-api-key", &self.api_key) + .json(&request); + + let event_source = EventSource::new(req).unwrap(); + + let mapped = event_source.filter_map(|event| { + let event = match event { + Ok(event) => event, + Err(reqwest_eventsource::Error::StreamEnded) => { + return Some(Err(Error::EventSourceClosedError)); + } + Err(e) => return Some(Err(e.into())), + }; + + let Event::Message(event_message) = event else { + return None; + }; + + let gemini_response: GenerateContentResponse = + match serde_json::from_str(&event_message.data) { + Ok(gemini_response) => gemini_response, + Err(e) => return Some(Err(e.into())), + }; + + let gemini_response = match gemini_response.into_result() { + Ok(gemini_response) => gemini_response, + Err(e) => return Some(Err(e)), + }; + + Some(Ok(gemini_response)) + }); + Ok(mapped) + } + + pub async fn stream_generate_content( + &self, + request: &GenerateContentRequest, + model: &str, + ) -> Arc>>> { + let queue = Arc::new(Queue::>>::new()); + + // Clone the queue and other necessary data to move into the async block. + let cloned_queue = queue.clone(); + let endpoint_url = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse" + ); + let client = self.client.clone(); + let request = request.clone(); + + let api_key = self.api_key.clone(); + // Start a thread to run the request in the background. + tokio::spawn(async move { + let req = client + .post(&endpoint_url) + .header("x-goog-api-key", api_key) + .json(&request); + + let mut event_source = match EventSource::new(req) { + Ok(event_source) => event_source, + Err(e) => { + cloned_queue.push(Some(Err(e.into()))); + return; + } + }; + while let Some(event) = event_source.next().await { + match event { + Ok(event) => { + if let Event::Message(event) = event { + let response: serde_json::error::Result = + serde_json::from_str(&event.data); + + match response { + Ok(response) => { + let result = response.into_result(); + let finished = match &result { + Ok(result) => result.candidates[0].finish_reason.is_some(), + Err(_) => true, + }; + cloned_queue.push(Some(result)); + if finished { + break; + } + } + Err(_) => { + tracing::error!("Error parsing message: {}", event.data); + break; + } + } + } + } + Err(e) => { + tracing::error!("Error in event source: {:?}", e); + break; + } + } + } + cloned_queue.push(None); + }); + + // Return the queue that will receive the responses. + queue + } + + pub async fn generate_content( + &self, + request: &GenerateContentRequest, + model: &str, + ) -> Result { + let endpoint_url: String = format!( + "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent", + ); + let resp = self + .client + .post(&endpoint_url) + .header("x-goog-api-key", &self.api_key) + .json(&request) + .send() + .await?; + + let status = resp.status(); + let txt_json = resp.text().await?; + tracing::debug!("generate_content response: {:?}", txt_json); + + if !status.is_success() { + if let Ok(gemini_error) = + serde_json::from_str::(&txt_json) + { + return Err(Error::GeminiError(gemini_error)); + } + // Fallback if parsing fails, though it should ideally match GeminiApiError + return Err(Error::GenericApiError { + status: status.as_u16(), + body: txt_json, + }); + } + + match serde_json::from_str::(&txt_json) { + Ok(response) => Ok(response.into_result()?), + Err(e) => { + tracing::error!("Failed to parse response: {} with error {}", txt_json, e); + Err(e.into()) + } + } + } + + /// Prompts a conversation to the model. + pub async fn prompt_conversation(&self, messages: &[Message], model: &str) -> Result { + let request = GenerateContentRequest { + contents: messages + .iter() + .map(|m| Content { + role: Some(m.role), + parts: Some(vec![Part::from_text(m.text.clone())]), + }) + .collect(), + generation_config: None, + tools: None, + system_instruction: None, + safety_settings: None, + }; + + let response = self.generate_content(&request, model).await?; + + // Check for errors in the response. + let mut candidates = GeminiClient::collect_text_from_response(&response); + + match candidates.pop() { + Some(text) => Ok(Message::new(Role::Model, &text)), + None => Err(Error::NoCandidatesError), + } + } + + fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec { + response + .candidates + .iter() + .filter_map(Candidate::get_text) + .collect::>() + } + + pub async fn text_embeddings( + &self, + request: &TextEmbeddingRequest, + model: &str, + ) -> Result { + let endpoint_url = + format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict"); + let resp = self + .client + .post(&endpoint_url) + .header("x-goog-api-key", &self.api_key) + .json(&request) + .send() + .await?; + let txt_json = resp.text().await?; + tracing::debug!("text_embeddings response: {:?}", txt_json); + Ok(serde_json::from_str::(&txt_json)?) + } + + pub async fn count_tokens( + &self, + request: &CountTokensRequest, + model: &str, + ) -> Result { + let endpoint_url = + format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens"); + let resp = self + .client + .post(&endpoint_url) + .header("x-goog-api-key", &self.api_key) + .json(&request) + .send() + .await?; + + let txt_json = resp.text().await?; + tracing::debug!("count_tokens response: {:?}", txt_json); + Ok(serde_json::from_str(&txt_json)?) + } + + pub async fn predict_image( + &self, + request: &PredictImageRequest, + model: &str, + ) -> Result { + let endpoint_url = + format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict"); + + let resp = self + .client + .post(&endpoint_url) + .header("x-goog-api-key", &self.api_key) + .json(&request) + .send() + .await?; + + let status = resp.status(); + let txt_json = resp.text().await?; + + if !status.is_success() { + if let Ok(gemini_error) = + serde_json::from_str::(&txt_json) + { + return Err(Error::GeminiError(gemini_error)); + } + return Err(Error::GenericApiError { + status: status.as_u16(), + body: txt_json, + }); + } + + match serde_json::from_str::(&txt_json) { + Ok(response) => Ok(response), + Err(e) => { + error!(response = txt_json, error = ?e, "Failed to parse response"); + Err(e.into()) + } + } + } +} diff --git a/src/dialogue.rs b/src/dialogue.rs new file mode 100644 index 0000000..d82b8ff --- /dev/null +++ b/src/dialogue.rs @@ -0,0 +1,42 @@ +use serde::{Deserialize, Serialize}; + +use crate::{client::GeminiClient, error::Result, types::Role}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub text: String, +} + +impl Message { + pub fn new(role: Role, text: &str) -> Self { + Message { + role, + text: text.to_string(), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Dialogue { + model: String, + messages: Vec, +} + +impl Dialogue { + pub fn new(model: &str) -> Self { + Dialogue { + model: model.to_string(), + messages: vec![], + } + } + + pub async fn do_turn(&mut self, gemini: &GeminiClient, message: &str) -> Result { + self.messages.push(Message::new(Role::User, message)); + let response = gemini + .prompt_conversation(&self.messages, &self.model) + .await?; + self.messages.push(response.clone()); + Ok(response) + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..fe05b60 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,96 @@ +use std::fmt::Display; + +use reqwest_eventsource::CannotCloneRequestError; + +use crate::types; + +pub type Result = std::result::Result; + +#[derive(Debug)] +pub enum Error { + Env(std::env::VarError), + HttpClient(reqwest::Error), + Serde(serde_json::Error), + VertexError(types::VertexApiError), + GeminiError(types::GeminiApiError), + NoCandidatesError, + CannotCloneRequestError(CannotCloneRequestError), + EventSourceError(Box), + EventSourceClosedError, + GenericApiError { status: u16, body: String }, +} + +impl Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self { + Error::Env(e) => write!(f, "Environment variable error: {e}"), + Error::HttpClient(e) => write!(f, "HTTP Client error: {e}"), + Error::Serde(e) => write!(f, "Serde error: {e}"), + Error::VertexError(e) => { + write!(f, "Vertex error: {e}") + } + Error::GeminiError(e) => { + write!(f, "Gemini error: {e}") + } + Error::NoCandidatesError => { + write!(f, "No candidates returned for the prompt") + } + Error::CannotCloneRequestError(e) => { + write!(f, "Cannot clone request: {e}") + } + Error::EventSourceError(e) => { + write!(f, "EventSourrce Error: {e}") + } + Error::EventSourceClosedError => { + write!(f, "EventSource closed error") + } + Error::GenericApiError { status, body } => { + write!(f, "API error (status {status}): {body}") + } + } + } +} + +impl std::error::Error for Error {} + +impl From for Error { + fn from(e: reqwest::Error) -> Self { + Error::HttpClient(e) + } +} + +impl From for Error { + fn from(e: std::env::VarError) -> Self { + Error::Env(e) + } +} + +impl From for Error { + fn from(e: serde_json::Error) -> Self { + Error::Serde(e) + } +} + +impl From for Error { + fn from(e: types::VertexApiError) -> Self { + Error::VertexError(e) + } +} + +impl From for Error { + fn from(e: types::GeminiApiError) -> Self { + Error::GeminiError(e) + } +} + +impl From for Error { + fn from(e: CannotCloneRequestError) -> Self { + Error::CannotCloneRequestError(e) + } +} + +impl From for Error { + fn from(e: reqwest_eventsource::Error) -> Self { + Error::EventSourceError(Box::new(e)) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b97fb81 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,10 @@ +mod client; +mod dialogue; +pub mod error; +mod types; + +pub mod prelude { + pub use crate::client::*; + pub use crate::dialogue::*; + pub use crate::types::*; +} diff --git a/src/types/common.rs b/src/types/common.rs new file mode 100644 index 0000000..e2c4287 --- /dev/null +++ b/src/types/common.rs @@ -0,0 +1,159 @@ +use std::{fmt::Display, str::FromStr, vec}; + +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +#[derive(Clone, Default, Debug, Serialize, Deserialize)] +pub struct Content { + pub role: Option, + pub parts: Option>, +} + +impl Content { + pub fn get_text(&self) -> Option { + self.parts.as_ref().map(|parts| { + parts + .iter() + .filter_map(|part| match &part.data { + PartData::Text(text) => Some(text.clone()), + _ => None, + }) + .collect::() + }) + } + + pub fn builder() -> ContentBuilder { + ContentBuilder::default() + } +} + +#[derive(Default)] +pub struct ContentBuilder { + content: Content, +} + +impl ContentBuilder { + pub fn add_text_part>(self, text: T) -> Self { + self.add_part(Part::from_text(text.into())) + } + + pub fn add_part(mut self, part: Part) -> Self { + match &mut self.content.parts { + Some(parts) => parts.push(part), + None => self.content.parts = Some(vec![part]), + } + self + } + + pub fn role(mut self, role: Role) -> Self { + self.content.role = Some(role); + self + } + + pub fn build(self) -> Content { + self.content + } +} + +#[derive(Clone, Copy, Debug, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + User, + Model, +} + +impl Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let role_str = match self { + Role::User => "user", + Role::Model => "model", + }; + f.write_str(role_str) + } +} + +impl FromStr for Role { + type Err = (); + + fn from_str(s: &str) -> std::result::Result { + match s { + "user" => Ok(Role::User), + "model" => Ok(Role::Model), + _ => Err(()), + } + } +} + +/// See https://ai.google.dev/api/caching#Part +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Part { + #[serde(skip_serializing_if = "Option::is_none")] + pub thought: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thought_signature: Option, + // This is of a Struct type, a Map of values, so either a Value or Map are appropriate. + //See https://protobuf.dev/reference/protobuf/google.protobuf/#struct + #[serde(skip_serializing_if = "Option::is_none")] + pub part_metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub media_resolution: Option, // TODO: Create type for media_resolution. + #[serde(flatten)] + pub data: PartData, // Create enum for data. +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum PartData { + Text(String), + // https://ai.google.dev/api/caching#Blob + InlineData { + mime_type: String, + data: String, + }, + // https://ai.google.dev/api/caching#FunctionCall + FunctionCall { + id: Option, + name: String, + args: Option, + }, + // https://ai.google.dev/api/caching#FunctionResponse + FunctionResponse { + id: Option, + name: String, + response: Value, + will_continue: Option, + // TODO: Add remaining fields. + }, + FileData(Value), + ExecutableCode(Value), + CodeExecutionResult(Value), +} + +impl Part { + pub fn from_text(text: String) -> Self { + Self { + thought: None, + thought_signature: None, + part_metadata: None, + media_resolution: None, + data: PartData::Text(text), + } + } +} + +#[cfg(test)] +mod test { + use crate::types::Part; + + #[test] + pub fn parses_text_part() { + let input = r#" + { + "text": "**What do you mean? An African or a European swallow?**\n\nIf you are looking for the actual physics rather than the *Monty Python and the Holy Grail* reference, here is the breakdown:\n\n**1. The European Swallow**\nBased on an analysis published by Jonathan Corum (using data on the Strouhal number of cruising flight), the estimated airspeed velocity of an unladen European Swallow is roughly **11 meters per second**, or **24 miles per hour**.\n\n**2. The African Swallow**\nData on the African swallow is scarcer, mostly because—as the guard in the movie points out—African swallows are non-migratory. However, since they are similar in size to their European counterparts, their cruising speed would likely be comparable.\n\nBut of course, the real question is: *Could it carry a coconut?* (A five-ounce bird could not carry a one-pound coconut. It is a simple question of weight ratios.)", + "thoughtSignature": "EqcZCqQZAdHtim/53UNFI7YRLcEDch1I/mLfWNT6lVjgXb7RsNnYn8JLU8Y6UhAi4nkLJ/nK2l44Y+JJZimQ2rLpRfdlBAPkhVsuZYenAY7MRXG9GQrSzz1elR+L6FAb0dyb9snnGz5NdlKCyS9VIWKIhghmHA60oEnEUexaJD2mq3ZV4kJ8R/d+UJEEdOD9CdlnB1WnOvHaiT15mLSj8JxclI+1mml86b5hjA0F+MLVWesa4gjo6/OfNo1k+tA+JioUAu8hgZ5DJttNxs/BvrLMyY/+d6qm40Ht45BuNlKUjFTkrUOIx5oAld3PnNj804Ou3F/sv8i5UMh9TcWyuiOjP3lZU5t1GEKQJ/YY9CxN/Zl71Kzk51Z+92IV2tKLqZVsEkrIr5o33QmNRTIeX0zMSQRdhlTBPuwSa+l91SV56cPK0I7P6UPguc3qGD8E3wfUC+fByDzX4JZ6OuhyrwcCCgbyjnBgI/FoWBA364cKONEH69p851Jy+zRaI9hWKKOQ/hqHqpWL266vgnALkvjcfZS3Frc6rRTvRIzetVufrJM3i9OAfnoLPZz5crraRQgUpgcPUd9fYhl59PIK35jRaENXunDUa8NE/J8kObcZE+910NxsUo7LzsGssr6UOPM6slKhnocnbqCrrNLhoF0jLXbSObuCXKh5HuGV8Y51UdsK6oUuct+ScfOZGBl+/6LhaGmlS0Ab58R7CO8UqhX4j91H8YW6xtDTQoAIXNU2j4Zq7lkpH0b5Vv7ZhFnbbc1OgTtboTcKwyRXgZFlBa6NNIb7GvRMyKdWW+sHXFAXGohZubp7DXsr6gQ/8eqcTuiiLKChRbY6MhG14OkGw4/LcuBAxEg6Fy7JX3tlMfto3LcfhFVvlmM1XuWACR9OJLr49YAkBYsMWl95qK5tSG0Wo/hAqjcPWPszrzK9Uo9AsDpsCHGnX57Ytcsi60y+jnV7iQqhoWtaT+UJW9FbxOPpKTsQw0k2GPM/1d+ulMz2IYPrN/Bsuk34OyAUID1zEUnSro0Q4camHfW2wnJvW77rLmfqO2b0M4+UuEgbgB/dyQtICsNndaO1x6S3pL8/typqoakwx/9xg02QVzLLRvfs4Su9eSAsKL/QfQCI9dmS8O0kvA1DqbUdxO6HfrfCVpGKoLajB4dZ/1nplNFFL+ap7vXOU9F4foXemT4f71T3S93NWb6gFU8jB8WxNaoWVBoeuP7iJNMqqBZPvV9SJ94lELlV/LZKlZ+pqQML/Gfe565AmXD34ekgE5ZGkwQxSoP8BksbDnL41GxEZtvWHcr+kSZK2FoTBwsXBye43qy1ZFYV+guSPqgsy5S215c2r4g+zfJ2vlC5+k2621Dwex7POA68LrtfbyeFJ8gQY7nZMPNp2gZQHmY/imA1Fb0jiCfMzYUiWumJeyOeiSUE5p/slwV0SryaYtT73fjx37F/iUAE5zl6yEo8v45aiB2XNgxdTU4bjHEFD+sj/6DGp27ukt6vLxN/QhmPvU7yYUA+u1WbQblof6VN7AwhVUqgqUx9Je0kSXPrI12K/2yC6eZnGuXeicqwIxCQWh9z9o24NzUkaiVC7VnSItVgXDWwviwAe4H1LxNU9y6j+Y0R8iGclRQVN8haBc1x7BWO6raGsLRrKblykBsIydnuz1Bvjk4eEaoH1rCzzIiuj1ZqG3bo/bLxjJw1h1KmnXkywo8alCusMIog71a3FQnST+idwJ9+tJU31rqMxinD1kUwG5ZYmFnpRZWHD57gsa5rzFptjbnkUxfBhHD3+7mO6qlgMidjzfv77MuFWRVyglDMD+eNvlX6vmPm93Qq4rDZTDssck6IYCaQ6TuqXJ2WEal0HDgaX/rlyhUL/4T7Ptk2/QoQqekUasvbjPhpn25R9AGTIcEwdoVsK2kC4ftvtkc2g1jE4PK2fLqe6sNfCEebZT18nx5FdgELbkSB+ss3aLfvWVVC0EJJmdlW+F1mxxPnkfvwcCfj4YKsfhEMoiPxbs0As2dtbaV9xcrhFlGZFoA/idudJqRPEuZvhtiJ2L0MQMuDWqT6kDr6wqnAghj2olacMb9rU5IlK9hfoCalMp7/adEJLpzJ7RdZd6o8cGq0D2v9lsT/2OJtq+kiMIG3gzIDrHSCK7v3XFpmA6DcMsgUHyYGSe1Mfe6fD+mPXyKWEi+hp3SJjDHa3Xk0bx5java0fZc/q/t9yxxjijIVGlRrduMj0GQpi3JHOL/JZoGWHrMSQFBmLIEypj+Dp1nImOja7j69VlK6q1dxELdx1sE5eIzTpk0/bRZ3oyqFtXYwyWUJsx5evdJSPIGbM8lgQsV8yO9U8LRot2BhWyfsU8NWRsHY5ihYb2K/Y9saE1iML4uqvIAK36eG9DuRaz2zIa6K3G5Xr/U8c0BxUxNNcWIra7TPyVmIXhLm85ghX9qKWNM2YQO/02tvIAI/9+8qANblayjg31j+FjME1NNGQg3jxA28QyfN39b0Fg8sD5MWmHP6MtvfVwx0JM88n1eCJiZ0No5BFUOB/EfgtiXp48ledg66cLjPmU9rjKPNyK4iUsRO7IY9X0/7L4M+d+8tBOy14Bfjn0ELi6HdF5+HVgWp3DViCn8iX4HCVrTX9S4/ZrgJVDJdI5axuGlsaH3VqCV0Rfes/p3MfcjUVOpBja+byTWMbM0ZONjrF3NAtzwZwLN+QDVEVS8Hso11mYsL6IvEbKsGYySBcX6qZ57p0MlPeC0GPPy0DkDca19W/fWFkrlPP60plNymq+c9HZ1Ghmg9YSGluckJLidqR6wuCSSkyaSwjJaJYnu4MIfXrLP4Q0UmKwvVJFSNqhtDSaus+U2+m8sl6CadTs4trw2iVh78/Wpghvido18f7A40MFo8E3OLN9XEgXA2FLMPrGiZM3JFTMutokburAgTAxs7CmbqilP4ArWvxEvG+TbmCatA5PhhGibms3OO910cjaToRUXriE8K7kHRM7Miui7qDcCM+wcgPOV+sYNNucAAbseGi+Mej1tmMLTUO4k8q2bRcadMaijASasX6Q8k6k1YGy89HTh1UkwCLdd6F4eYHsDFpMGjwJ2I1fJ/4lmTAUYOHP3n4p4ovOSoptgIul9sty7iqZnQlkQHeVWQSwMzyBbcxTqA6GDsdNk5GF+Wjaf3C3F+uOhRY+yD0wbb43d3rpEMPkThbTTsN8ricg0bDSIWnM2FKfsQ0QFbZuC2JrkeSEZuLd3RldLsUXBzrQl2ub49oztmjEQSu6GePyz9LAeQRJd6EUQ4/I/vu1SLyHcXZAch4zrzk2u+7OWehE+i/CGzRWL14/x+z3PPmguYOqS1rJdCWDIKlIXD9nZc/heFhQ4QiV2pvr0ElYHCDnAq/SgpPC7EFy4BGmz6cMJ2Az44cijzOFbYZ1+rkbxvLV4Q2QVDj5tgBNYrV7FYBs+B0kF3D/ijbp1JGowGDsXJC1KaUpu01OL9962042O3b4RIU6NsGa0irMip/IAlFYhEW72Aj6oNvqNKDf7VjT3GYvRRz51zPMaKymBLCDw2lSrz7tTkN8L3w7dyLzBpzNI894Id3B6lf+ummAp+w0y0Q/jQnNzUFJznXIoais7JwcC+jxkolAW6iCXwGYGbYLTgV1jKH1GdJf10yzMo/obPF2F4vtRITmq3PGRV1DEm9ELbu3ajhSP4vh9eUqxki/ORrJibn6MVBz1GtzOzFBZ8br2ZZLCxqq4bTVj/BXPngVZ6bxmxdn7rf15a4IcPRZ9hPEl/M3vIl6cSJLb3M45lADfDtBW70dXMFAcof2ipkngcOf2NY/dYuGUMMyOp/Xetvy4kFY2ye2nU0PEq0GhwxxCB/zrGzxprC7W93sVETAlPXyb9yirlo4elyaNIZMt+sqlUHGoFyK3xDPlkNAwrsQWgghNwMrtZ7Fm683n38X9HVwgGjJpeoODfIph+f0vDl+ncO2GywdSbJXQg5Tf5PTONvZb+8Kd2F7Lv8mljtqAKHoh/b7MyogGvA914hUL3jKFClnAaD9xXWCK83stRL2Hqg2PmY+aNwB3m/Y54QEYdq+Xu7nIWo8EkncKTB4GwLb7Cyep88E5WNnyaU1Y337xAEGN9403pqCp+abrFgMOLl1MAPoWXNEGsQIEqVJECkPgpdR1eU83LjPXjSthCe5mo2Vc35IgOOA94UEDfaXyRqQEE5CH+QRdXCc4oMKt3cTUHiPlPbHayKVH5d1lntDxMgJ5tSN1kwcFQMKYXJdYSZqatoYNar0tnSF2EuGPs2ium1h4Il/NKCPiZySDbYRwDITMu+RVMvr5CbmXHF93bz/d0n8Qg8A2qmrU=" + }"#; + + let _ = serde_json::from_str::(input).unwrap(); + } +} diff --git a/src/types/count_tokens.rs b/src/types/count_tokens.rs new file mode 100644 index 0000000..79af363 --- /dev/null +++ b/src/types/count_tokens.rs @@ -0,0 +1,49 @@ +use serde::{Deserialize, Serialize}; + +use super::Content; + +#[derive(Debug, Serialize, Deserialize)] +pub struct CountTokensRequest { + pub contents: Content, +} + +impl CountTokensRequest { + pub fn builder() -> CountTokensRequestBuilder { + CountTokensRequestBuilder::default() + } +} + +#[derive(Default)] +pub struct CountTokensRequestBuilder { + contents: Content, +} + +impl CountTokensRequestBuilder { + pub fn from_prompt(prompt: &str) -> Self { + CountTokensRequestBuilder { + contents: Content { + parts: Some(vec![super::Part::from_text(prompt.to_string())]), + ..Default::default() + }, + } + } + + pub fn build(self) -> CountTokensRequest { + CountTokensRequest { + contents: self.contents, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum CountTokensResponse { + #[serde(rename_all = "camelCase")] + Ok { + total_tokens: i32, + total_billable_characters: u32, + }, + Error { + error: super::VertexApiError, + }, +} diff --git a/src/types/error.rs b/src/types/error.rs new file mode 100644 index 0000000..80b69f1 --- /dev/null +++ b/src/types/error.rs @@ -0,0 +1,67 @@ +use std::fmt::Formatter; + +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct VertexApiError { + pub code: i32, + pub message: String, + pub status: String, + pub details: Option>, +} + +impl core::fmt::Display for VertexApiError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + writeln!(f, "Vertex API Error {} - {}", self.code, self.message)?; + Ok(()) + } +} + +impl std::error::Error for VertexApiError {} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct GeminiApiError { + pub error: VertexApiError, +} + +impl core::fmt::Display for GeminiApiError { + fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { + write!(f, "Gemini API Error {} - {}", self.error.code, self.error.message) + } +} + +impl std::error::Error for GeminiApiError {} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Link { + pub description: String, + pub url: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "@type")] +pub enum ErrorType { + #[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")] + ErrorInfo { metadata: ErrorInfoMetadata }, + + #[serde(rename = "type.googleapis.com/google.rpc.Help")] + Help { links: Vec }, + + #[serde(rename = "type.googleapis.com/google.rpc.BadRequest")] + BadRequest { + #[serde(rename = "fieldViolations")] + field_violations: Vec, + }, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ErrorInfoMetadata { + pub service: String, + pub consumer: String, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FieldViolation { + pub field: String, + pub description: String, +} diff --git a/src/types/generate_content.rs b/src/types/generate_content.rs new file mode 100644 index 0000000..20d8dc8 --- /dev/null +++ b/src/types/generate_content.rs @@ -0,0 +1,622 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; + +use super::{Content, VertexApiError}; +use crate::error::Result; + +#[derive(Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateContentRequest { + pub contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub generation_config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_settings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_instruction: Option, +} + +impl GenerateContentRequest { + pub fn builder() -> GenerateContentRequestBuilder { + GenerateContentRequestBuilder::new() + } +} + +pub struct GenerateContentRequestBuilder { + request: GenerateContentRequest, +} + +impl GenerateContentRequestBuilder { + fn new() -> Self { + GenerateContentRequestBuilder { + request: GenerateContentRequest::default(), + } + } + + pub fn contents(mut self, contents: Vec) -> Self { + self.request.contents = contents; + self + } + + pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self { + self.request.generation_config = Some(generation_config); + self + } + + pub fn tools(mut self, tools: Vec) -> Self { + self.request.tools = Some(tools); + self + } + + pub fn safety_settings(mut self, safety_settings: Vec) -> Self { + self.request.safety_settings = Some(safety_settings); + self + } + + pub fn system_instruction(mut self, system_instruction: Content) -> Self { + self.request.system_instruction = Some(system_instruction); + self + } + + pub fn build(self) -> GenerateContentRequest { + self.request + } +} + +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct Tools { + #[serde(skip_serializing_if = "Option::is_none")] + pub function_declarations: Option>, + + #[serde(rename = "googleSearchRetrieval")] + #[serde(skip_serializing_if = "Option::is_none")] + pub google_search_retrieval: Option, + + #[serde(skip_serializing_if = "Option::is_none")] + pub google_search: Option, +} + +#[derive(Clone, Default, Serialize, Deserialize)] +pub struct GoogleSearch {} + +#[derive(Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct DynamicRetrievalConfig { + pub mode: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub dynamic_threshold: Option, +} + +impl Default for DynamicRetrievalConfig { + fn default() -> Self { + Self { + mode: "MODE_DYNAMIC".to_string(), + dynamic_threshold: Some(0.7), + } + } +} + +#[derive(Clone, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GoogleSearchRetrieval { + pub dynamic_retrieval_config: DynamicRetrievalConfig, +} + +#[derive(Clone, Debug, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +pub struct GenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub candidate_count: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_mime_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_schema: Option, +} + +impl GenerationConfig { + pub fn builder() -> GenerationConfigBuilder { + GenerationConfigBuilder::new() + } +} + +pub struct GenerationConfigBuilder { + generation_config: GenerationConfig, +} + +impl GenerationConfigBuilder { + fn new() -> Self { + Self { + generation_config: Default::default(), + } + } + + pub fn max_output_tokens>(mut self, max_output_tokens: T) -> Self { + self.generation_config.max_output_tokens = Some(max_output_tokens.into()); + self + } + + pub fn temperature>(mut self, temperature: T) -> Self { + self.generation_config.temperature = Some(temperature.into()); + self + } + + pub fn top_p>(mut self, top_p: T) -> Self { + self.generation_config.top_p = Some(top_p.into()); + self + } + + pub fn top_k>(mut self, top_k: T) -> Self { + self.generation_config.top_k = Some(top_k.into()); + self + } + + pub fn stop_sequences>>(mut self, stop_sequences: T) -> Self { + self.generation_config.stop_sequences = Some(stop_sequences.into()); + self + } + + pub fn candidate_count>(mut self, candidate_count: T) -> Self { + self.generation_config.candidate_count = Some(candidate_count.into()); + self + } + + pub fn response_mime_type>(mut self, response_mime_type: T) -> Self { + self.generation_config.response_mime_type = Some(response_mime_type.into()); + self + } + + pub fn response_schema>(mut self, response_schema: T) -> Self { + self.generation_config.response_schema = Some(response_schema.into()); + self + } + + pub fn build(self) -> GenerationConfig { + self.generation_config + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SafetySetting { + pub category: HarmCategory, + pub threshold: HarmBlockThreshold, + #[serde(skip_serializing_if = "Option::is_none")] + pub method: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum HarmCategory { + #[serde(rename = "HARM_CATEGORY_UNSPECIFIED")] + Unspecified, + #[serde(rename = "HARM_CATEGORY_HATE_SPEECH")] + HateSpeech, + #[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")] + DangerousContent, + #[serde(rename = "HARM_CATEGORY_HARASSMENT")] + Harassment, + #[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")] + SexuallyExplicit, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum HarmBlockThreshold { + #[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")] + Unspecified, + #[serde(rename = "BLOCK_LOW_AND_ABOVE")] + BlockLowAndAbove, + #[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")] + BlockMediumAndAbove, + #[serde(rename = "BLOCK_ONLY_HIGH")] + BlockOnlyHigh, + #[serde(rename = "BLOCK_NONE")] + BlockNone, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum HarmBlockMethod { + #[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")] + Unspecified, // HARM_BLOCK_METHOD_UNSPECIFIED + #[serde(rename = "SEVERITY")] + Severity, // SEVERITY + #[serde(rename = "PROBABILITY")] + Probability, // PROBABILITY +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Candidate { + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub citation_metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_ratings: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub finish_reason: Option, + pub index: u32, +} + +impl Candidate { + pub fn get_text(&self) -> Option { + match &self.content { + Some(content) => content.get_text(), + None => None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Citation { + pub start_index: Option, + pub end_index: Option, + pub uri: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct CitationMetadata { + #[serde(alias = "citationSources")] + pub citations: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SafetyRating { + pub category: String, + pub probability: String, + pub probability_score: Option, + pub severity: String, + pub severity_score: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UsageMetadata { + pub candidates_token_count: Option, + pub prompt_token_count: Option, + pub total_token_count: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionDeclaration { + pub name: String, + pub description: String, + // TODO: add behaviour field - https://ai.google.dev/api/caching#Behavior + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parameters_json_schema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_json_schema: Option, +} + +/// See https://ai.google.dev/api/caching#FunctionResponse +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub name: String, + pub response: Value, + // TODO: Add missing properties from docs. +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct FunctionParametersProperty { + pub r#type: String, + pub description: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum GenerateContentResponse { + Ok(GenerateContentResponseResult), + Error(GenerateContentResponseError), +} + +impl From for Result { + fn from(val: GenerateContentResponse) -> Self { + match val { + GenerateContentResponse::Ok(result) => Ok(result), + GenerateContentResponse::Error(error) => Err(error.error.into()), + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct GenerateContentResponseResult { + pub candidates: Vec, + pub usage_metadata: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GenerateContentResponseError { + pub error: VertexApiError, +} + +impl GenerateContentResponse { + pub fn into_result(self) -> Result { + match self { + GenerateContentResponse::Ok(result) => Ok(result), + GenerateContentResponse::Error(error) => Err(error.error.into()), + } + } +} + +#[cfg(test)] +mod tests { + use crate::types::{Candidate, UsageMetadata}; + + use super::{GenerateContentResponse, GenerateContentResponseResult}; + + #[test] + pub fn parses_empty_metadata_response() { + let input = r#"{"candidates": [{"content": {"role": "model","parts": [{"text": "-"}]}}],"usageMetadata": {}}"#; + serde_json::from_str::(input).unwrap(); + } + + #[test] + pub fn parses_usage_metadata() { + let input = r#" + { + "promptTokenCount": 11, + "candidatesTokenCount": 202, + "totalTokenCount": 1041, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 11 + } + ], + "thoughtsTokenCount": 828 + }"#; + let _ = serde_json::from_str::(input).unwrap(); + } + + #[test] + pub fn parses_candidate() { + let input = r#" + { + "content": { + "parts": [ + { + "text": "**What do you mean? An African or a European swallow?**\n\nIf you are looking for the actual physics rather than the *Monty Python and the Holy Grail* reference, here is the breakdown:\n\n**1. The European Swallow**\nBased on an analysis published by Jonathan Corum (using data on the Strouhal number of cruising flight), the estimated airspeed velocity of an unladen European Swallow is roughly **11 meters per second**, or **24 miles per hour**.\n\n**2. The African Swallow**\nData on the African swallow is scarcer, mostly because—as the guard in the movie points out—African swallows are non-migratory. However, since they are similar in size to their European counterparts, their cruising speed would likely be comparable.\n\nBut of course, the real question is: *Could it carry a coconut?* (A five-ounce bird could not carry a one-pound coconut. It is a simple question of weight ratios.)", + "thoughtSignature": "EqcZCqQZAdHtim/53UNFI7YRLcEDch1I/mLfWNT6lVjgXb7RsNnYn8JLU8Y6UhAi4nkLJ/nK2l44Y+JJZimQ2rLpRfdlBAPkhVsuZYenAY7MRXG9GQrSzz1elR+L6FAb0dyb9snnGz5NdlKCyS9VIWKIhghmHA60oEnEUexaJD2mq3ZV4kJ8R/d+UJEEdOD9CdlnB1WnOvHaiT15mLSj8JxclI+1mml86b5hjA0F+MLVWesa4gjo6/OfNo1k+tA+JioUAu8hgZ5DJttNxs/BvrLMyY/+d6qm40Ht45BuNlKUjFTkrUOIx5oAld3PnNj804Ou3F/sv8i5UMh9TcWyuiOjP3lZU5t1GEKQJ/YY9CxN/Zl71Kzk51Z+92IV2tKLqZVsEkrIr5o33QmNRTIeX0zMSQRdhlTBPuwSa+l91SV56cPK0I7P6UPguc3qGD8E3wfUC+fByDzX4JZ6OuhyrwcCCgbyjnBgI/FoWBA364cKONEH69p851Jy+zRaI9hWKKOQ/hqHqpWL266vgnALkvjcfZS3Frc6rRTvRIzetVufrJM3i9OAfnoLPZz5crraRQgUpgcPUd9fYhl59PIK35jRaENXunDUa8NE/J8kObcZE+910NxsUo7LzsGssr6UOPM6slKhnocnbqCrrNLhoF0jLXbSObuCXKh5HuGV8Y51UdsK6oUuct+ScfOZGBl+/6LhaGmlS0Ab58R7CO8UqhX4j91H8YW6xtDTQoAIXNU2j4Zq7lkpH0b5Vv7ZhFnbbc1OgTtboTcKwyRXgZFlBa6NNIb7GvRMyKdWW+sHXFAXGohZubp7DXsr6gQ/8eqcTuiiLKChRbY6MhG14OkGw4/LcuBAxEg6Fy7JX3tlMfto3LcfhFVvlmM1XuWACR9OJLr49YAkBYsMWl95qK5tSG0Wo/hAqjcPWPszrzK9Uo9AsDpsCHGnX57Ytcsi60y+jnV7iQqhoWtaT+UJW9FbxOPpKTsQw0k2GPM/1d+ulMz2IYPrN/Bsuk34OyAUID1zEUnSro0Q4camHfW2wnJvW77rLmfqO2b0M4+UuEgbgB/dyQtICsNndaO1x6S3pL8/typqoakwx/9xg02QVzLLRvfs4Su9eSAsKL/QfQCI9dmS8O0kvA1DqbUdxO6HfrfCVpGKoLajB4dZ/1nplNFFL+ap7vXOU9F4foXemT4f71T3S93NWb6gFU8jB8WxNaoWVBoeuP7iJNMqqBZPvV9SJ94lELlV/LZKlZ+pqQML/Gfe565AmXD34ekgE5ZGkwQxSoP8BksbDnL41GxEZtvWHcr+kSZK2FoTBwsXBye43qy1ZFYV+guSPqgsy5S215c2r4g+zfJ2vlC5+k2621Dwex7POA68LrtfbyeFJ8gQY7nZMPNp2gZQHmY/imA1Fb0jiCfMzYUiWumJeyOeiSUE5p/slwV0SryaYtT73fjx37F/iUAE5zl6yEo8v45aiB2XNgxdTU4bjHEFD+sj/6DGp27ukt6vLxN/QhmPvU7yYUA+u1WbQblof6VN7AwhVUqgqUx9Je0kSXPrI12K/2yC6eZnGuXeicqwIxCQWh9z9o24NzUkaiVC7VnSItVgXDWwviwAe4H1LxNU9y6j+Y0R8iGclRQVN8haBc1x7BWO6raGsLRrKblykBsIydnuz1Bvjk4eEaoH1rCzzIiuj1ZqG3bo/bLxjJw1h1KmnXkywo8alCusMIog71a3FQnST+idwJ9+tJU31rqMxinD1kUwG5ZYmFnpRZWHD57gsa5rzFptjbnkUxfBhHD3+7mO6qlgMidjzfv77MuFWRVyglDMD+eNvlX6vmPm93Qq4rDZTDssck6IYCaQ6TuqXJ2WEal0HDgaX/rlyhUL/4T7Ptk2/QoQqekUasvbjPhpn25R9AGTIcEwdoVsK2kC4ftvtkc2g1jE4PK2fLqe6sNfCEebZT18nx5FdgELbkSB+ss3aLfvWVVC0EJJmdlW+F1mxxPnkfvwcCfj4YKsfhEMoiPxbs0As2dtbaV9xcrhFlGZFoA/idudJqRPEuZvhtiJ2L0MQMuDWqT6kDr6wqnAghj2olacMb9rU5IlK9hfoCalMp7/adEJLpzJ7RdZd6o8cGq0D2v9lsT/2OJtq+kiMIG3gzIDrHSCK7v3XFpmA6DcMsgUHyYGSe1Mfe6fD+mPXyKWEi+hp3SJjDHa3Xk0bx5java0fZc/q/t9yxxjijIVGlRrduMj0GQpi3JHOL/JZoGWHrMSQFBmLIEypj+Dp1nImOja7j69VlK6q1dxELdx1sE5eIzTpk0/bRZ3oyqFtXYwyWUJsx5evdJSPIGbM8lgQsV8yO9U8LRot2BhWyfsU8NWRsHY5ihYb2K/Y9saE1iML4uqvIAK36eG9DuRaz2zIa6K3G5Xr/U8c0BxUxNNcWIra7TPyVmIXhLm85ghX9qKWNM2YQO/02tvIAI/9+8qANblayjg31j+FjME1NNGQg3jxA28QyfN39b0Fg8sD5MWmHP6MtvfVwx0JM88n1eCJiZ0No5BFUOB/EfgtiXp48ledg66cLjPmU9rjKPNyK4iUsRO7IY9X0/7L4M+d+8tBOy14Bfjn0ELi6HdF5+HVgWp3DViCn8iX4HCVrTX9S4/ZrgJVDJdI5axuGlsaH3VqCV0Rfes/p3MfcjUVOpBja+byTWMbM0ZONjrF3NAtzwZwLN+QDVEVS8Hso11mYsL6IvEbKsGYySBcX6qZ57p0MlPeC0GPPy0DkDca19W/fWFkrlPP60plNymq+c9HZ1Ghmg9YSGluckJLidqR6wuCSSkyaSwjJaJYnu4MIfXrLP4Q0UmKwvVJFSNqhtDSaus+U2+m8sl6CadTs4trw2iVh78/Wpghvido18f7A40MFo8E3OLN9XEgXA2FLMPrGiZM3JFTMutokburAgTAxs7CmbqilP4ArWvxEvG+TbmCatA5PhhGibms3OO910cjaToRUXriE8K7kHRM7Miui7qDcCM+wcgPOV+sYNNucAAbseGi+Mej1tmMLTUO4k8q2bRcadMaijASasX6Q8k6k1YGy89HTh1UkwCLdd6F4eYHsDFpMGjwJ2I1fJ/4lmTAUYOHP3n4p4ovOSoptgIul9sty7iqZnQlkQHeVWQSwMzyBbcxTqA6GDsdNk5GF+Wjaf3C3F+uOhRY+yD0wbb43d3rpEMPkThbTTsN8ricg0bDSIWnM2FKfsQ0QFbZuC2JrkeSEZuLd3RldLsUXBzrQl2ub49oztmjEQSu6GePyz9LAeQRJd6EUQ4/I/vu1SLyHcXZAch4zrzk2u+7OWehE+i/CGzRWL14/x+z3PPmguYOqS1rJdCWDIKlIXD9nZc/heFhQ4QiV2pvr0ElYHCDnAq/SgpPC7EFy4BGmz6cMJ2Az44cijzOFbYZ1+rkbxvLV4Q2QVDj5tgBNYrV7FYBs+B0kF3D/ijbp1JGowGDsXJC1KaUpu01OL9962042O3b4RIU6NsGa0irMip/IAlFYhEW72Aj6oNvqNKDf7VjT3GYvRRz51zPMaKymBLCDw2lSrz7tTkN8L3w7dyLzBpzNI894Id3B6lf+ummAp+w0y0Q/jQnNzUFJznXIoais7JwcC+jxkolAW6iCXwGYGbYLTgV1jKH1GdJf10yzMo/obPF2F4vtRITmq3PGRV1DEm9ELbu3ajhSP4vh9eUqxki/ORrJibn6MVBz1GtzOzFBZ8br2ZZLCxqq4bTVj/BXPngVZ6bxmxdn7rf15a4IcPRZ9hPEl/M3vIl6cSJLb3M45lADfDtBW70dXMFAcof2ipkngcOf2NY/dYuGUMMyOp/Xetvy4kFY2ye2nU0PEq0GhwxxCB/zrGzxprC7W93sVETAlPXyb9yirlo4elyaNIZMt+sqlUHGoFyK3xDPlkNAwrsQWgghNwMrtZ7Fm683n38X9HVwgGjJpeoODfIph+f0vDl+ncO2GywdSbJXQg5Tf5PTONvZb+8Kd2F7Lv8mljtqAKHoh/b7MyogGvA914hUL3jKFClnAaD9xXWCK83stRL2Hqg2PmY+aNwB3m/Y54QEYdq+Xu7nIWo8EkncKTB4GwLb7Cyep88E5WNnyaU1Y337xAEGN9403pqCp+abrFgMOLl1MAPoWXNEGsQIEqVJECkPgpdR1eU83LjPXjSthCe5mo2Vc35IgOOA94UEDfaXyRqQEE5CH+QRdXCc4oMKt3cTUHiPlPbHayKVH5d1lntDxMgJ5tSN1kwcFQMKYXJdYSZqatoYNar0tnSF2EuGPs2ium1h4Il/NKCPiZySDbYRwDITMu+RVMvr5CbmXHF93bz/d0n8Qg8A2qmrU=" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + }"#; + let _ = serde_json::from_str::(input).unwrap(); + } + + #[test] + pub fn parses_google_ai_response() { + let input = r#" + { + "candidates": [ + { + "content": { + "parts": [ + { + "text": "**What do you mean? An African or a European swallow?**\n\nIf you are looking for the actual physics rather than the *Monty Python and the Holy Grail* reference, here is the breakdown:\n\n**1. The European Swallow**\nBased on an analysis published by Jonathan Corum (using data on the Strouhal number of cruising flight), the estimated airspeed velocity of an unladen European Swallow is roughly **11 meters per second**, or **24 miles per hour**.\n\n**2. The African Swallow**\nData on the African swallow is scarcer, mostly because—as the guard in the movie points out—African swallows are non-migratory. However, since they are similar in size to their European counterparts, their cruising speed would likely be comparable.\n\nBut of course, the real question is: *Could it carry a coconut?* (A five-ounce bird could not carry a one-pound coconut. It is a simple question of weight ratios.)", + "thoughtSignature": "EqcZCqQZAdHtim/53UNFI7YRLcEDch1I/mLfWNT6lVjgXb7RsNnYn8JLU8Y6UhAi4nkLJ/nK2l44Y+JJZimQ2rLpRfdlBAPkhVsuZYenAY7MRXG9GQrSzz1elR+L6FAb0dyb9snnGz5NdlKCyS9VIWKIhghmHA60oEnEUexaJD2mq3ZV4kJ8R/d+UJEEdOD9CdlnB1WnOvHaiT15mLSj8JxclI+1mml86b5hjA0F+MLVWesa4gjo6/OfNo1k+tA+JioUAu8hgZ5DJttNxs/BvrLMyY/+d6qm40Ht45BuNlKUjFTkrUOIx5oAld3PnNj804Ou3F/sv8i5UMh9TcWyuiOjP3lZU5t1GEKQJ/YY9CxN/Zl71Kzk51Z+92IV2tKLqZVsEkrIr5o33QmNRTIeX0zMSQRdhlTBPuwSa+l91SV56cPK0I7P6UPguc3qGD8E3wfUC+fByDzX4JZ6OuhyrwcCCgbyjnBgI/FoWBA364cKONEH69p851Jy+zRaI9hWKKOQ/hqHqpWL266vgnALkvjcfZS3Frc6rRTvRIzetVufrJM3i9OAfnoLPZz5crraRQgUpgcPUd9fYhl59PIK35jRaENXunDUa8NE/J8kObcZE+910NxsUo7LzsGssr6UOPM6slKhnocnbqCrrNLhoF0jLXbSObuCXKh5HuGV8Y51UdsK6oUuct+ScfOZGBl+/6LhaGmlS0Ab58R7CO8UqhX4j91H8YW6xtDTQoAIXNU2j4Zq7lkpH0b5Vv7ZhFnbbc1OgTtboTcKwyRXgZFlBa6NNIb7GvRMyKdWW+sHXFAXGohZubp7DXsr6gQ/8eqcTuiiLKChRbY6MhG14OkGw4/LcuBAxEg6Fy7JX3tlMfto3LcfhFVvlmM1XuWACR9OJLr49YAkBYsMWl95qK5tSG0Wo/hAqjcPWPszrzK9Uo9AsDpsCHGnX57Ytcsi60y+jnV7iQqhoWtaT+UJW9FbxOPpKTsQw0k2GPM/1d+ulMz2IYPrN/Bsuk34OyAUID1zEUnSro0Q4camHfW2wnJvW77rLmfqO2b0M4+UuEgbgB/dyQtICsNndaO1x6S3pL8/typqoakwx/9xg02QVzLLRvfs4Su9eSAsKL/QfQCI9dmS8O0kvA1DqbUdxO6HfrfCVpGKoLajB4dZ/1nplNFFL+ap7vXOU9F4foXemT4f71T3S93NWb6gFU8jB8WxNaoWVBoeuP7iJNMqqBZPvV9SJ94lELlV/LZKlZ+pqQML/Gfe565AmXD34ekgE5ZGkwQxSoP8BksbDnL41GxEZtvWHcr+kSZK2FoTBwsXBye43qy1ZFYV+guSPqgsy5S215c2r4g+zfJ2vlC5+k2621Dwex7POA68LrtfbyeFJ8gQY7nZMPNp2gZQHmY/imA1Fb0jiCfMzYUiWumJeyOeiSUE5p/slwV0SryaYtT73fjx37F/iUAE5zl6yEo8v45aiB2XNgxdTU4bjHEFD+sj/6DGp27ukt6vLxN/QhmPvU7yYUA+u1WbQblof6VN7AwhVUqgqUx9Je0kSXPrI12K/2yC6eZnGuXeicqwIxCQWh9z9o24NzUkaiVC7VnSItVgXDWwviwAe4H1LxNU9y6j+Y0R8iGclRQVN8haBc1x7BWO6raGsLRrKblykBsIydnuz1Bvjk4eEaoH1rCzzIiuj1ZqG3bo/bLxjJw1h1KmnXkywo8alCusMIog71a3FQnST+idwJ9+tJU31rqMxinD1kUwG5ZYmFnpRZWHD57gsa5rzFptjbnkUxfBhHD3+7mO6qlgMidjzfv77MuFWRVyglDMD+eNvlX6vmPm93Qq4rDZTDssck6IYCaQ6TuqXJ2WEal0HDgaX/rlyhUL/4T7Ptk2/QoQqekUasvbjPhpn25R9AGTIcEwdoVsK2kC4ftvtkc2g1jE4PK2fLqe6sNfCEebZT18nx5FdgELbkSB+ss3aLfvWVVC0EJJmdlW+F1mxxPnkfvwcCfj4YKsfhEMoiPxbs0As2dtbaV9xcrhFlGZFoA/idudJqRPEuZvhtiJ2L0MQMuDWqT6kDr6wqnAghj2olacMb9rU5IlK9hfoCalMp7/adEJLpzJ7RdZd6o8cGq0D2v9lsT/2OJtq+kiMIG3gzIDrHSCK7v3XFpmA6DcMsgUHyYGSe1Mfe6fD+mPXyKWEi+hp3SJjDHa3Xk0bx5java0fZc/q/t9yxxjijIVGlRrduMj0GQpi3JHOL/JZoGWHrMSQFBmLIEypj+Dp1nImOja7j69VlK6q1dxELdx1sE5eIzTpk0/bRZ3oyqFtXYwyWUJsx5evdJSPIGbM8lgQsV8yO9U8LRot2BhWyfsU8NWRsHY5ihYb2K/Y9saE1iML4uqvIAK36eG9DuRaz2zIa6K3G5Xr/U8c0BxUxNNcWIra7TPyVmIXhLm85ghX9qKWNM2YQO/02tvIAI/9+8qANblayjg31j+FjME1NNGQg3jxA28QyfN39b0Fg8sD5MWmHP6MtvfVwx0JM88n1eCJiZ0No5BFUOB/EfgtiXp48ledg66cLjPmU9rjKPNyK4iUsRO7IY9X0/7L4M+d+8tBOy14Bfjn0ELi6HdF5+HVgWp3DViCn8iX4HCVrTX9S4/ZrgJVDJdI5axuGlsaH3VqCV0Rfes/p3MfcjUVOpBja+byTWMbM0ZONjrF3NAtzwZwLN+QDVEVS8Hso11mYsL6IvEbKsGYySBcX6qZ57p0MlPeC0GPPy0DkDca19W/fWFkrlPP60plNymq+c9HZ1Ghmg9YSGluckJLidqR6wuCSSkyaSwjJaJYnu4MIfXrLP4Q0UmKwvVJFSNqhtDSaus+U2+m8sl6CadTs4trw2iVh78/Wpghvido18f7A40MFo8E3OLN9XEgXA2FLMPrGiZM3JFTMutokburAgTAxs7CmbqilP4ArWvxEvG+TbmCatA5PhhGibms3OO910cjaToRUXriE8K7kHRM7Miui7qDcCM+wcgPOV+sYNNucAAbseGi+Mej1tmMLTUO4k8q2bRcadMaijASasX6Q8k6k1YGy89HTh1UkwCLdd6F4eYHsDFpMGjwJ2I1fJ/4lmTAUYOHP3n4p4ovOSoptgIul9sty7iqZnQlkQHeVWQSwMzyBbcxTqA6GDsdNk5GF+Wjaf3C3F+uOhRY+yD0wbb43d3rpEMPkThbTTsN8ricg0bDSIWnM2FKfsQ0QFbZuC2JrkeSEZuLd3RldLsUXBzrQl2ub49oztmjEQSu6GePyz9LAeQRJd6EUQ4/I/vu1SLyHcXZAch4zrzk2u+7OWehE+i/CGzRWL14/x+z3PPmguYOqS1rJdCWDIKlIXD9nZc/heFhQ4QiV2pvr0ElYHCDnAq/SgpPC7EFy4BGmz6cMJ2Az44cijzOFbYZ1+rkbxvLV4Q2QVDj5tgBNYrV7FYBs+B0kF3D/ijbp1JGowGDsXJC1KaUpu01OL9962042O3b4RIU6NsGa0irMip/IAlFYhEW72Aj6oNvqNKDf7VjT3GYvRRz51zPMaKymBLCDw2lSrz7tTkN8L3w7dyLzBpzNI894Id3B6lf+ummAp+w0y0Q/jQnNzUFJznXIoais7JwcC+jxkolAW6iCXwGYGbYLTgV1jKH1GdJf10yzMo/obPF2F4vtRITmq3PGRV1DEm9ELbu3ajhSP4vh9eUqxki/ORrJibn6MVBz1GtzOzFBZ8br2ZZLCxqq4bTVj/BXPngVZ6bxmxdn7rf15a4IcPRZ9hPEl/M3vIl6cSJLb3M45lADfDtBW70dXMFAcof2ipkngcOf2NY/dYuGUMMyOp/Xetvy4kFY2ye2nU0PEq0GhwxxCB/zrGzxprC7W93sVETAlPXyb9yirlo4elyaNIZMt+sqlUHGoFyK3xDPlkNAwrsQWgghNwMrtZ7Fm683n38X9HVwgGjJpeoODfIph+f0vDl+ncO2GywdSbJXQg5Tf5PTONvZb+8Kd2F7Lv8mljtqAKHoh/b7MyogGvA914hUL3jKFClnAaD9xXWCK83stRL2Hqg2PmY+aNwB3m/Y54QEYdq+Xu7nIWo8EkncKTB4GwLb7Cyep88E5WNnyaU1Y337xAEGN9403pqCp+abrFgMOLl1MAPoWXNEGsQIEqVJECkPgpdR1eU83LjPXjSthCe5mo2Vc35IgOOA94UEDfaXyRqQEE5CH+QRdXCc4oMKt3cTUHiPlPbHayKVH5d1lntDxMgJ5tSN1kwcFQMKYXJdYSZqatoYNar0tnSF2EuGPs2ium1h4Il/NKCPiZySDbYRwDITMu+RVMvr5CbmXHF93bz/d0n8Qg8A2qmrU=" + } + ], + "role": "model" + }, + "finishReason": "STOP", + "index": 0 + } + ], + "usageMetadata": { + "promptTokenCount": 11, + "candidatesTokenCount": 202, + "totalTokenCount": 1041, + "promptTokensDetails": [ + { + "modality": "TEXT", + "tokenCount": 11 + } + ], + "thoughtsTokenCount": 828 + }, + "modelVersion": "gemini-3-pro-preview", + "responseId": "2uUdaYPkG73WvdIP2aPs2Ak" + } + "#; + let _ = serde_json::from_str::(input).unwrap(); + } + + #[test] + pub fn parses_max_tokens_response() { + let input = r#"{ + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Service workers are powerful and absolutely worth learning. They let you deliver an entirely new level of experience to your users. Your site can load instantly . It can work offline . It can be installed as a platform-specific app and feel every bit as polished—but with the reach and freedom of the web." + } + ] + }, + "finishReason": "MAX_TOKENS", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.03882902, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.05781161 + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.07626997, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.06705628 + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.05749328, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.027532939 + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.12929276, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.17838266 + } + ], + "citationMetadata": { + "citations": [ + { + "endIndex": 151, + "uri": "https://web.dev/service-worker-mindset/" + }, + { + "startIndex": 93, + "endIndex": 297, + "uri": "https://web.dev/service-worker-mindset/" + }, + { + "endIndex": 297 + } + ] + } + } + ], + "usageMetadata": { + "promptTokenCount": 12069, + "candidatesTokenCount": 61, + "totalTokenCount": 12130 + } + }"#; + serde_json::from_str::(input).unwrap(); + } + + #[test] + fn parses_candidates_without_content() { + let input = r#"{ + "candidates": [ + { + "finishReason": "RECITATION", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.08021325, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.0721122 + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.19360436, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.1066906 + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.07751766, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.040769264 + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "probabilityScore": 0.030792166, + "severity": "HARM_SEVERITY_NEGLIGIBLE", + "severityScore": 0.04138472 + } + ], + "citationMetadata": { + "citations": [ + { + "startIndex": 1108, + "endIndex": 1250, + "uri": "https://chrome.google.com/webstore/detail/autocontrol-shortcut-mana/lkaihdpfpifdlgoapbfocpmekbokmcfd?hl=zh-TW" + } + ] + } + } + ], + "usageMetadata": { + "promptTokenCount": 577, + "totalTokenCount": 577 + } + }"#; + serde_json::from_str::(input).unwrap(); + } + + #[test] + fn parses_safety_rating_without_scores() { + let input = r#"{ + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + { + "text": "Return text" + } + ] + }, + "finishReason": "STOP", + "safetyRatings": [ + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + "severity": "HARM_SEVERITY_NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + "severity": "HARM_SEVERITY_NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + "severity": "HARM_SEVERITY_NEGLIGIBLE" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + "severity": "HARM_SEVERITY_NEGLIGIBLE" + } + ] + } + ], + "usageMetadata": { + "promptTokenCount": 5492, + "candidatesTokenCount": 1256, + "totalTokenCount": 6748 + } + }"#; + serde_json::from_str::(input).unwrap(); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs new file mode 100644 index 0000000..8fa1ed5 --- /dev/null +++ b/src/types/mod.rs @@ -0,0 +1,13 @@ +mod common; +mod count_tokens; +mod error; +mod generate_content; +mod predict_image; +mod text_embeddings; + +pub use common::*; +pub use count_tokens::*; +pub use error::*; +pub use generate_content::*; +pub use predict_image::*; +pub use text_embeddings::*; diff --git a/src/types/predict_image.rs b/src/types/predict_image.rs new file mode 100644 index 0000000..2eaef5f --- /dev/null +++ b/src/types/predict_image.rs @@ -0,0 +1,187 @@ +use serde::{Deserialize, Serialize}; +use serde_with::base64::Base64; +use serde_with::serde_as; + +#[derive(Debug, Serialize, Deserialize)] +pub struct PredictImageRequest { + pub instances: Vec, + pub parameters: PredictImageRequestParameters, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PredictImageRequestPrompt { + /// The text prompt for the image. + /// The following models support different values for this parameter: + /// - `imagen-3.0-generate-001`: up to 480 tokens. + /// - `imagen-3.0-fast-generate-001`: up to 480 tokens. + /// - `imagegeneration@006`: up to 128 tokens. + /// - `imagegeneration@005`: up to 128 tokens. + /// - `imagegeneration@002`: up to 64 tokens. + pub prompt: String, +} + +#[derive(Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PredictImageRequestParameters { + /// The number of images to generate. The default value is 4. + /// The following models support different values for this parameter: + /// - `imagen-3.0-generate-001`: 1 to 4. + /// - `imagen-3.0-fast-generate-001`: 1 to 4. + /// - `imagegeneration@006`: 1 to 4. + /// - `imagegeneration@005`: 1 to 4. + /// - `imagegeneration@002`: 1 to 8. + pub sample_count: i32, + + /// The random seed for image generation. This is not available when addWatermark is set to + /// true. + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + + /// Optional. An optional parameter to use an LLM-based prompt rewriting feature to deliver + /// higher quality images that better reflect the original prompt's intent. Disabling this + /// feature may impact image quality and prompt adherence + #[serde(skip_serializing_if = "Option::is_none")] + pub enhance_prompt: Option, + + /// A description of what to discourage in the generated images. + /// The following models support this parameter: + /// - `imagen-3.0-generate-001`: up to 480 tokens. + /// - `imagen-3.0-fast-generate-001`: up to 480 tokens. + /// - `imagegeneration@006`: up to 128 tokens. + /// - `imagegeneration@005`: up to 128 tokens. + /// - `imagegeneration@002`: up to 64 tokens. + #[serde(skip_serializing_if = "Option::is_none")] + pub negative_prompt: Option, + + /// The aspect ratio for the image. The default value is "1:1". + /// The following models support different values for this parameter: + /// - `imagen-3.0-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3". + /// - `imagen-3.0-fast-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3". + /// - `imagegeneration@006`: "1:1", "9:16", "16:9", "3:4", or "4:3". + /// - `imagegeneration@005`: "1:1" or "9:16". + /// - `imagegeneration@002`: "1:1". + #[serde(skip_serializing_if = "Option::is_none")] + pub aspect_ratio: Option, + + /// Describes the output image format in an `PredictImageRequestParametersOutputOptions + /// object. + #[serde(skip_serializing_if = "Option::is_none")] + pub output_options: Option, + + /// Describes the style for the generated images. The following values are supported: + /// - "photograph" + /// - "digital_art" + /// - "landscape" + /// - "sketch" + /// - "watercolor" + /// - "cyberpunk" + /// - "pop_art" + /// + /// Pre-defined styles is only supported for model imagegeneration@002 + #[serde(skip_serializing_if = "Option::is_none")] + pub sample_image_style: Option, + + /// Allow generation of people by the model. The following values are supported: + /// - `"dont_allow"`: Disallow the inclusion of people or faces in images. + /// - `"allow_adult"`: Allow generation of adults only. + /// - `"allow_all"`: Allow generation of people of all ages. + /// + /// The default value is `"allow_adult"`. + /// + /// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and + /// `imagegeneration@006` only. + #[serde(skip_serializing_if = "Option::is_none")] + pub person_generation: Option, + + /// Optional. The language code that corresponds to your text prompt language. + /// The following values are supported: + /// - auto: Automatic detection. If Imagen detects a supported language, the prompt and an + /// optional negative prompt are translated to English. If the language detected isn't + /// supported, Imagen uses the input text verbatim, which might result in an unexpected + /// output. No error code is returned. + /// - en: English (if omitted, the default value) + /// - zh or zh-CN: Chinese (simplified) + /// - zh-TW: Chinese (traditional) + /// - hi: Hindi + /// - ja: Japanese + /// - ko: Korean + /// - pt: Portuguese + /// - es: Spanish + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, + + /// Adds a filter level to safety filtering. The following values are supported: + /// + /// - "block_low_and_above": Strongest filtering level, most strict blocking. + /// Deprecated value: "block_most". + /// - "block_medium_and_above": Block some problematic prompts and responses. + /// Deprecated value: "block_some". + /// - "block_only_high": Reduces the number of requests blocked due to safety filters. May + /// increase objectionable content generated by Imagen. Deprecated value: "block_few". + /// - "block_none": Block very few problematic prompts and responses. Access to this feature + /// is restricted. Previous field value: "block_fewest". + /// + /// The default value is "block_medium_and_above". + /// + /// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and + /// `imagegeneration@006` only. + #[serde(skip_serializing_if = "Option::is_none")] + pub safety_setting: Option, + + /// Add an invisible watermark to the generated images. The default value is `false` for the + /// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the + /// `imagen-3.0-fast-generate-001`, `imagegeneration@006`, and imagegeneration@006 models. + #[serde(skip_serializing_if = "Option::is_none")] + pub add_watermark: Option, + + /// Cloud Storage URI to store the generated images. + #[serde(skip_serializing_if = "Option::is_none")] + pub storage_uri: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PredictImageRequestParametersOutputOptions { + /// The image format that the output should be saved as. The following values are supported: + /// + /// - "image/png": Save as a PNG image + /// - "image/jpeg": Save as a JPEG image + /// + /// The default value is "image/png".v + pub mime_type: Option, + + /// The level of compression if the output type is "image/jpeg". + /// Accepted values are 0 through 100. The default value is 75. + pub compression_quality: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct PredictImageResponse { + pub predictions: Vec, +} + +#[serde_as] +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PredictImageResponsePrediction { + #[serde_as(as = "Base64")] + pub bytes_base64_encoded: Vec, + pub mime_type: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PersonGeneration { + DontAllow, + AllowAdult, + AllowAll, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PredictImageSafetySetting { + BlockLowAndAbove, + BlockMediumAndAbove, + BlockOnlyHigh, + BlockNone, +} diff --git a/src/types/text_embeddings.rs b/src/types/text_embeddings.rs new file mode 100644 index 0000000..3f1917d --- /dev/null +++ b/src/types/text_embeddings.rs @@ -0,0 +1,61 @@ +use serde::{Deserialize, Serialize}; + +use crate::error::{Error, Result}; + +use super::VertexApiError; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingRequest { + pub instances: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingRequestInstance { + pub content: String, + pub task_type: String, + pub title: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum TextEmbeddingResponse { + Ok(TextEmbeddingResponseOk), + Error { error: VertexApiError }, +} + +impl TextEmbeddingResponse { + pub fn into_result(self) -> Result { + self.into() + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingResponseOk { + pub predictions: Vec, +} + +impl From for Result { + fn from(value: TextEmbeddingResponse) -> Self { + match value { + TextEmbeddingResponse::Ok(ok) => Ok(ok), + TextEmbeddingResponse::Error { error } => Err(Error::VertexError(error)), + } + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingPrediction { + pub embeddings: TextEmbeddingResult, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingResult { + pub statistics: TextEmbeddingStatistics, + pub values: Vec, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TextEmbeddingStatistics { + pub truncated: bool, + pub token_count: u32, +}