Initial Commit
This commit is contained in:
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
.env
|
||||
27
Cargo.toml
Normal file
27
Cargo.toml
Normal file
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "google-genai"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
deadqueue = "0.2"
|
||||
reqwest = { version = "0.12", features = ["json", "gzip"] }
|
||||
reqwest-eventsource = "0.6"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = { version = "1"}
|
||||
serde_with = { version = "3.9", features = ["base64"]}
|
||||
tracing = "0.1"
|
||||
tokio = { version = "1" }
|
||||
tokio-stream = "0.1.17"
|
||||
|
||||
[dev-dependencies]
|
||||
console = "0.16.0"
|
||||
dialoguer = "0.12.0"
|
||||
dotenvy = "0.15.7"
|
||||
image = "0.25.2"
|
||||
indicatif = "0.18.0"
|
||||
tokio = { version = "1.47.0", features = ["full"] }
|
||||
tracing-subscriber = "0.3.20"
|
||||
|
||||
113
examples/function-call.rs
Normal file
113
examples/function-call.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
use std::{env, error::Error};
|
||||
|
||||
use google_genai::prelude::{
|
||||
Content, FunctionDeclaration, GeminiClient, GenerateContentRequest, Part, PartData, Role, Tools,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
tracing_subscriber::fmt().init();
|
||||
let _ = dotenvy::dotenv();
|
||||
|
||||
let api_key = env::var("GEMINI_API_KEY")?;
|
||||
let gemini_client = GeminiClient::new(api_key);
|
||||
|
||||
let mut contents = vec![
|
||||
Content::builder()
|
||||
.role(Role::User)
|
||||
.add_text_part("What is the sbrubbles value of 213 and 231?")
|
||||
.build(),
|
||||
];
|
||||
|
||||
let sum_parameters = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"left": {
|
||||
"type": "integer",
|
||||
"description": "The first value for the sbrubbles calculation"
|
||||
},
|
||||
"right": {
|
||||
"type": "integer",
|
||||
"description": "The second value for the sbrubbles calculation"
|
||||
}
|
||||
},
|
||||
"required": ["left", "right"]
|
||||
});
|
||||
|
||||
let sum_result = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"result": {
|
||||
"type": "integer",
|
||||
"description": "The sbrubbles value calculation result",
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let sum_function = FunctionDeclaration {
|
||||
name: String::from("sbrubbles"),
|
||||
description: String::from("Calculates the sbrubbles value"),
|
||||
parameters: None,
|
||||
parameters_json_schema: Some(sum_parameters),
|
||||
response: None,
|
||||
response_json_schema: Some(sum_result),
|
||||
};
|
||||
|
||||
println!("{}", serde_json::to_string_pretty(&sum_function).unwrap());
|
||||
|
||||
let tools = Tools {
|
||||
function_declarations: Some(vec![sum_function]),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let request = GenerateContentRequest::builder()
|
||||
.contents(contents.clone())
|
||||
.tools(vec![tools.clone()])
|
||||
.build();
|
||||
|
||||
let mut response = gemini_client
|
||||
.generate_content(&request, "gemini-3-pro-preview")
|
||||
.await?;
|
||||
|
||||
while let Some(candidate) = response.candidates.last()
|
||||
&& let Some(content) = &candidate.content
|
||||
&& let Some(parts) = &content.parts
|
||||
&& let Some(part) = parts.last()
|
||||
&& let PartData::FunctionCall { id, name, args, .. } = &part.data
|
||||
{
|
||||
contents.push(content.clone());
|
||||
match args {
|
||||
Some(args) => println!("Function call: {name}, {args}"),
|
||||
None => println!("Function call: {name}"),
|
||||
}
|
||||
|
||||
contents.push(Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part {
|
||||
data: PartData::FunctionResponse {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
response: json!({"result": 1234}),
|
||||
will_continue: None,
|
||||
},
|
||||
media_resolution: None,
|
||||
part_metadata: None,
|
||||
thought: None,
|
||||
thought_signature: part.thought_signature.clone(),
|
||||
}]),
|
||||
});
|
||||
let request = GenerateContentRequest::builder()
|
||||
.contents(contents.clone())
|
||||
.tools(vec![tools.clone()])
|
||||
.build();
|
||||
|
||||
println!("{contents:?}");
|
||||
response = gemini_client
|
||||
.generate_content(&request, "gemini-3-pro-preview")
|
||||
.await?;
|
||||
}
|
||||
|
||||
println!("Response: {:#?}", response.candidates);
|
||||
Ok(())
|
||||
}
|
||||
26
examples/generate-content.rs
Normal file
26
examples/generate-content.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use std::{env, error::Error};
|
||||
|
||||
use google_genai::prelude::{Content, GeminiClient, GenerateContentRequest, Role};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
tracing_subscriber::fmt().init();
|
||||
let _ = dotenvy::dotenv();
|
||||
|
||||
let api_key = env::var("GEMINI_API_KEY")?;
|
||||
let gemini_client = GeminiClient::new(api_key);
|
||||
let prompt = vec![
|
||||
Content::builder()
|
||||
.role(Role::User)
|
||||
.add_text_part("What is the airspeed of an unladen swallow?")
|
||||
.build(),
|
||||
];
|
||||
|
||||
let request = GenerateContentRequest::builder().contents(prompt).build();
|
||||
let response = gemini_client
|
||||
.generate_content(&request, "gemini-3-pro-preview")
|
||||
.await?;
|
||||
|
||||
println!("Response: {:?}", response.candidates[0].get_text().unwrap());
|
||||
Ok(())
|
||||
}
|
||||
50
examples/generate-image.rs
Normal file
50
examples/generate-image.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use std::{env, error::Error, io::Cursor};
|
||||
|
||||
use google_genai::prelude::{
|
||||
GeminiClient, PersonGeneration, PredictImageRequest, PredictImageRequestParameters,
|
||||
PredictImageRequestParametersOutputOptions, PredictImageRequestPrompt,
|
||||
PredictImageSafetySetting,
|
||||
};
|
||||
use image::{ImageFormat, ImageReader};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn Error>> {
|
||||
tracing_subscriber::fmt().init();
|
||||
let _ = dotenvy::dotenv();
|
||||
|
||||
let api_key = env::var("GEMINI_API_KEY")?;
|
||||
let gemini_client = GeminiClient::new(api_key);
|
||||
|
||||
let prompt = "
|
||||
Create an image of a tuxedo cat riding a rocket to the moon.";
|
||||
let request = PredictImageRequest {
|
||||
instances: vec![PredictImageRequestPrompt {
|
||||
prompt: prompt.to_string(),
|
||||
}],
|
||||
parameters: PredictImageRequestParameters {
|
||||
sample_count: 1,
|
||||
aspect_ratio: Some("1:1".to_string()),
|
||||
output_options: Some(PredictImageRequestParametersOutputOptions {
|
||||
mime_type: Some("image/jpeg".to_string()),
|
||||
compression_quality: Some(75),
|
||||
}),
|
||||
person_generation: Some(PersonGeneration::AllowAdult),
|
||||
safety_setting: Some(PredictImageSafetySetting::BlockLowAndAbove),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
println!("Request: {:#?}", serde_json::to_string(&request).unwrap());
|
||||
|
||||
let mut result = gemini_client
|
||||
.predict_image(&request, "imagen-4.0-generate-001")
|
||||
.await?;
|
||||
|
||||
let result = result.predictions.pop().unwrap();
|
||||
|
||||
let format = ImageFormat::from_mime_type(result.mime_type).unwrap();
|
||||
let img =
|
||||
ImageReader::with_format(Cursor::new(result.bytes_base64_encoded), format).decode()?;
|
||||
img.save("output.jpg")?;
|
||||
Ok(())
|
||||
}
|
||||
309
src/client.rs
Normal file
309
src/client.rs
Normal file
@@ -0,0 +1,309 @@
|
||||
use crate::error::Result as GeminiResult;
|
||||
use std::sync::Arc;
|
||||
use std::vec;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
|
||||
use deadqueue::unlimited::Queue;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use tracing::error;
|
||||
|
||||
use crate::dialogue::Message;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::Part;
|
||||
use crate::prelude::{
|
||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||
GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||
|
||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct GeminiClient {
|
||||
client: reqwest::Client,
|
||||
api_key: String,
|
||||
}
|
||||
|
||||
unsafe impl Send for GeminiClient {}
|
||||
unsafe impl Sync for GeminiClient {}
|
||||
|
||||
impl GeminiClient {
|
||||
pub fn new(api_key: String) -> Self {
|
||||
GeminiClient {
|
||||
client: reqwest::Client::new(),
|
||||
api_key,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn generate_content_stream(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Result<impl Stream<Item = GeminiResult<GenerateContentResponseResult>>> {
|
||||
let endpoint_url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse"
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let request = request.clone();
|
||||
let req = client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&request);
|
||||
|
||||
let event_source = EventSource::new(req).unwrap();
|
||||
|
||||
let mapped = event_source.filter_map(|event| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(reqwest_eventsource::Error::StreamEnded) => {
|
||||
return Some(Err(Error::EventSourceClosedError));
|
||||
}
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let Event::Message(event_message) = event else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let gemini_response: GenerateContentResponse =
|
||||
match serde_json::from_str(&event_message.data) {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let gemini_response = match gemini_response.into_result() {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
|
||||
Some(Ok(gemini_response))
|
||||
});
|
||||
Ok(mapped)
|
||||
}
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Arc<Queue<Option<Result<GenerateContentResponseResult>>>> {
|
||||
let queue = Arc::new(Queue::<Option<Result<GenerateContentResponseResult>>>::new());
|
||||
|
||||
// Clone the queue and other necessary data to move into the async block.
|
||||
let cloned_queue = queue.clone();
|
||||
let endpoint_url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{model}:streamGenerateContent?alt=sse"
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let request = request.clone();
|
||||
|
||||
let api_key = self.api_key.clone();
|
||||
// Start a thread to run the request in the background.
|
||||
tokio::spawn(async move {
|
||||
let req = client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", api_key)
|
||||
.json(&request);
|
||||
|
||||
let mut event_source = match EventSource::new(req) {
|
||||
Ok(event_source) => event_source,
|
||||
Err(e) => {
|
||||
cloned_queue.push(Some(Err(e.into())));
|
||||
return;
|
||||
}
|
||||
};
|
||||
while let Some(event) = event_source.next().await {
|
||||
match event {
|
||||
Ok(event) => {
|
||||
if let Event::Message(event) = event {
|
||||
let response: serde_json::error::Result<GenerateContentResponse> =
|
||||
serde_json::from_str(&event.data);
|
||||
|
||||
match response {
|
||||
Ok(response) => {
|
||||
let result = response.into_result();
|
||||
let finished = match &result {
|
||||
Ok(result) => result.candidates[0].finish_reason.is_some(),
|
||||
Err(_) => true,
|
||||
};
|
||||
cloned_queue.push(Some(result));
|
||||
if finished {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
tracing::error!("Error parsing message: {}", event.data);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("Error in event source: {:?}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
cloned_queue.push(None);
|
||||
});
|
||||
|
||||
// Return the queue that will receive the responses.
|
||||
queue
|
||||
}
|
||||
|
||||
pub async fn generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Result<GenerateContentResponseResult> {
|
||||
let endpoint_url: String = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
|
||||
);
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("generate_content response: {:?}", txt_json);
|
||||
|
||||
if !status.is_success() {
|
||||
if let Ok(gemini_error) =
|
||||
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||
{
|
||||
return Err(Error::GeminiError(gemini_error));
|
||||
}
|
||||
// Fallback if parsing fails, though it should ideally match GeminiApiError
|
||||
return Err(Error::GenericApiError {
|
||||
status: status.as_u16(),
|
||||
body: txt_json,
|
||||
});
|
||||
}
|
||||
|
||||
match serde_json::from_str::<GenerateContentResponse>(&txt_json) {
|
||||
Ok(response) => Ok(response.into_result()?),
|
||||
Err(e) => {
|
||||
tracing::error!("Failed to parse response: {} with error {}", txt_json, e);
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Prompts a conversation to the model.
|
||||
pub async fn prompt_conversation(&self, messages: &[Message], model: &str) -> Result<Message> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: messages
|
||||
.iter()
|
||||
.map(|m| Content {
|
||||
role: Some(m.role),
|
||||
parts: Some(vec![Part::from_text(m.text.clone())]),
|
||||
})
|
||||
.collect(),
|
||||
generation_config: None,
|
||||
tools: None,
|
||||
system_instruction: None,
|
||||
safety_settings: None,
|
||||
};
|
||||
|
||||
let response = self.generate_content(&request, model).await?;
|
||||
|
||||
// Check for errors in the response.
|
||||
let mut candidates = GeminiClient::collect_text_from_response(&response);
|
||||
|
||||
match candidates.pop() {
|
||||
Some(text) => Ok(Message::new(Role::Model, &text)),
|
||||
None => Err(Error::NoCandidatesError),
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
|
||||
response
|
||||
.candidates
|
||||
.iter()
|
||||
.filter_map(Candidate::get_text)
|
||||
.collect::<Vec<String>>()
|
||||
}
|
||||
|
||||
pub async fn text_embeddings(
|
||||
&self,
|
||||
request: &TextEmbeddingRequest,
|
||||
model: &str,
|
||||
) -> Result<TextEmbeddingResponse> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("text_embeddings response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
|
||||
}
|
||||
|
||||
pub async fn count_tokens(
|
||||
&self,
|
||||
request: &CountTokensRequest,
|
||||
model: &str,
|
||||
) -> Result<CountTokensResponse> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:countTokens");
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str(&txt_json)?)
|
||||
}
|
||||
|
||||
pub async fn predict_image(
|
||||
&self,
|
||||
request: &PredictImageRequest,
|
||||
model: &str,
|
||||
) -> Result<PredictImageResponse> {
|
||||
let endpoint_url =
|
||||
format!("https://generativelanguage.googleapis.com/v1beta/models/{model}:predict");
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.header("x-goog-api-key", &self.api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = resp.status();
|
||||
let txt_json = resp.text().await?;
|
||||
|
||||
if !status.is_success() {
|
||||
if let Ok(gemini_error) =
|
||||
serde_json::from_str::<crate::types::GeminiApiError>(&txt_json)
|
||||
{
|
||||
return Err(Error::GeminiError(gemini_error));
|
||||
}
|
||||
return Err(Error::GenericApiError {
|
||||
status: status.as_u16(),
|
||||
body: txt_json,
|
||||
});
|
||||
}
|
||||
|
||||
match serde_json::from_str::<PredictImageResponse>(&txt_json) {
|
||||
Ok(response) => Ok(response),
|
||||
Err(e) => {
|
||||
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
42
src/dialogue.rs
Normal file
42
src/dialogue.rs
Normal file
@@ -0,0 +1,42 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{client::GeminiClient, error::Result, types::Role};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Message {
|
||||
pub role: Role,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: Role, text: &str) -> Self {
|
||||
Message {
|
||||
role,
|
||||
text: text.to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Dialogue {
|
||||
model: String,
|
||||
messages: Vec<Message>,
|
||||
}
|
||||
|
||||
impl Dialogue {
|
||||
pub fn new(model: &str) -> Self {
|
||||
Dialogue {
|
||||
model: model.to_string(),
|
||||
messages: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_turn(&mut self, gemini: &GeminiClient, message: &str) -> Result<Message> {
|
||||
self.messages.push(Message::new(Role::User, message));
|
||||
let response = gemini
|
||||
.prompt_conversation(&self.messages, &self.model)
|
||||
.await?;
|
||||
self.messages.push(response.clone());
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
96
src/error.rs
Normal file
96
src/error.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use reqwest_eventsource::CannotCloneRequestError;
|
||||
|
||||
use crate::types;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Error {
|
||||
Env(std::env::VarError),
|
||||
HttpClient(reqwest::Error),
|
||||
Serde(serde_json::Error),
|
||||
VertexError(types::VertexApiError),
|
||||
GeminiError(types::GeminiApiError),
|
||||
NoCandidatesError,
|
||||
CannotCloneRequestError(CannotCloneRequestError),
|
||||
EventSourceError(Box<reqwest_eventsource::Error>),
|
||||
EventSourceClosedError,
|
||||
GenericApiError { status: u16, body: String },
|
||||
}
|
||||
|
||||
impl Display for Error {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self {
|
||||
Error::Env(e) => write!(f, "Environment variable error: {e}"),
|
||||
Error::HttpClient(e) => write!(f, "HTTP Client error: {e}"),
|
||||
Error::Serde(e) => write!(f, "Serde error: {e}"),
|
||||
Error::VertexError(e) => {
|
||||
write!(f, "Vertex error: {e}")
|
||||
}
|
||||
Error::GeminiError(e) => {
|
||||
write!(f, "Gemini error: {e}")
|
||||
}
|
||||
Error::NoCandidatesError => {
|
||||
write!(f, "No candidates returned for the prompt")
|
||||
}
|
||||
Error::CannotCloneRequestError(e) => {
|
||||
write!(f, "Cannot clone request: {e}")
|
||||
}
|
||||
Error::EventSourceError(e) => {
|
||||
write!(f, "EventSourrce Error: {e}")
|
||||
}
|
||||
Error::EventSourceClosedError => {
|
||||
write!(f, "EventSource closed error")
|
||||
}
|
||||
Error::GenericApiError { status, body } => {
|
||||
write!(f, "API error (status {status}): {body}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl From<reqwest::Error> for Error {
|
||||
fn from(e: reqwest::Error) -> Self {
|
||||
Error::HttpClient(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<std::env::VarError> for Error {
|
||||
fn from(e: std::env::VarError) -> Self {
|
||||
Error::Env(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<serde_json::Error> for Error {
|
||||
fn from(e: serde_json::Error) -> Self {
|
||||
Error::Serde(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<types::VertexApiError> for Error {
|
||||
fn from(e: types::VertexApiError) -> Self {
|
||||
Error::VertexError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<types::GeminiApiError> for Error {
|
||||
fn from(e: types::GeminiApiError) -> Self {
|
||||
Error::GeminiError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CannotCloneRequestError> for Error {
|
||||
fn from(e: CannotCloneRequestError) -> Self {
|
||||
Error::CannotCloneRequestError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest_eventsource::Error> for Error {
|
||||
fn from(e: reqwest_eventsource::Error) -> Self {
|
||||
Error::EventSourceError(Box::new(e))
|
||||
}
|
||||
}
|
||||
10
src/lib.rs
Normal file
10
src/lib.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
mod client;
|
||||
mod dialogue;
|
||||
pub mod error;
|
||||
mod types;
|
||||
|
||||
pub mod prelude {
|
||||
pub use crate::client::*;
|
||||
pub use crate::dialogue::*;
|
||||
pub use crate::types::*;
|
||||
}
|
||||
159
src/types/common.rs
Normal file
159
src/types/common.rs
Normal file
@@ -0,0 +1,159 @@
|
||||
use std::{fmt::Display, str::FromStr, vec};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Clone, Default, Debug, Serialize, Deserialize)]
|
||||
pub struct Content {
|
||||
pub role: Option<Role>,
|
||||
pub parts: Option<Vec<Part>>,
|
||||
}
|
||||
|
||||
impl Content {
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
self.parts.as_ref().map(|parts| {
|
||||
parts
|
||||
.iter()
|
||||
.filter_map(|part| match &part.data {
|
||||
PartData::Text(text) => Some(text.clone()),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<String>()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn builder() -> ContentBuilder {
|
||||
ContentBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ContentBuilder {
|
||||
content: Content,
|
||||
}
|
||||
|
||||
impl ContentBuilder {
|
||||
pub fn add_text_part<T: Into<String>>(self, text: T) -> Self {
|
||||
self.add_part(Part::from_text(text.into()))
|
||||
}
|
||||
|
||||
pub fn add_part(mut self, part: Part) -> Self {
|
||||
match &mut self.content.parts {
|
||||
Some(parts) => parts.push(part),
|
||||
None => self.content.parts = Some(vec![part]),
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn role(mut self, role: Role) -> Self {
|
||||
self.content.role = Some(role);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Content {
|
||||
self.content
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
User,
|
||||
Model,
|
||||
}
|
||||
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let role_str = match self {
|
||||
Role::User => "user",
|
||||
Role::Model => "model",
|
||||
};
|
||||
f.write_str(role_str)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Role {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s {
|
||||
"user" => Ok(Role::User),
|
||||
"model" => Ok(Role::Model),
|
||||
_ => Err(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// See https://ai.google.dev/api/caching#Part
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Part {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thought: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub thought_signature: Option<String>,
|
||||
// This is of a Struct type, a Map of values, so either a Value or Map<String, Value> are appropriate.
|
||||
//See https://protobuf.dev/reference/protobuf/google.protobuf/#struct
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub part_metadata: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub media_resolution: Option<Value>, // TODO: Create type for media_resolution.
|
||||
#[serde(flatten)]
|
||||
pub data: PartData, // Create enum for data.
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub enum PartData {
|
||||
Text(String),
|
||||
// https://ai.google.dev/api/caching#Blob
|
||||
InlineData {
|
||||
mime_type: String,
|
||||
data: String,
|
||||
},
|
||||
// https://ai.google.dev/api/caching#FunctionCall
|
||||
FunctionCall {
|
||||
id: Option<String>,
|
||||
name: String,
|
||||
args: Option<Value>,
|
||||
},
|
||||
// https://ai.google.dev/api/caching#FunctionResponse
|
||||
FunctionResponse {
|
||||
id: Option<String>,
|
||||
name: String,
|
||||
response: Value,
|
||||
will_continue: Option<bool>,
|
||||
// TODO: Add remaining fields.
|
||||
},
|
||||
FileData(Value),
|
||||
ExecutableCode(Value),
|
||||
CodeExecutionResult(Value),
|
||||
}
|
||||
|
||||
impl Part {
|
||||
pub fn from_text(text: String) -> Self {
|
||||
Self {
|
||||
thought: None,
|
||||
thought_signature: None,
|
||||
part_metadata: None,
|
||||
media_resolution: None,
|
||||
data: PartData::Text(text),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use crate::types::Part;
|
||||
|
||||
#[test]
|
||||
pub fn parses_text_part() {
|
||||
let input = r#"
|
||||
{
|
||||
"text": "**What do you mean? An African or a European swallow?**\n\nIf you are looking for the actual physics rather than the *Monty Python and the Holy Grail* reference, here is the breakdown:\n\n**1. The European Swallow**\nBased on an analysis published by Jonathan Corum (using data on the Strouhal number of cruising flight), the estimated airspeed velocity of an unladen European Swallow is roughly **11 meters per second**, or **24 miles per hour**.\n\n**2. The African Swallow**\nData on the African swallow is scarcer, mostly because—as the guard in the movie points out—African swallows are non-migratory. However, since they are similar in size to their European counterparts, their cruising speed would likely be comparable.\n\nBut of course, the real question is: *Could it carry a coconut?* (A five-ounce bird could not carry a one-pound coconut. It is a simple question of weight ratios.)",
|
||||
"thoughtSignature": "EqcZCqQZAdHtim/53UNFI7YRLcEDch1I/mLfWNT6lVjgXb7RsNnYn8JLU8Y6UhAi4nkLJ/nK2l44Y+JJZimQ2rLpRfdlBAPkhVsuZYenAY7MRXG9GQrSzz1elR+L6FAb0dyb9snnGz5NdlKCyS9VIWKIhghmHA60oEnEUexaJD2mq3ZV4kJ8R/d+UJEEdOD9CdlnB1WnOvHaiT15mLSj8JxclI+1mml86b5hjA0F+MLVWesa4gjo6/OfNo1k+tA+JioUAu8hgZ5DJttNxs/BvrLMyY/+d6qm40Ht45BuNlKUjFTkrUOIx5oAld3PnNj804Ou3F/sv8i5UMh9TcWyuiOjP3lZU5t1GEKQJ/YY9CxN/Zl71Kzk51Z+92IV2tKLqZVsEkrIr5o33QmNRTIeX0zMSQRdhlTBPuwSa+l91SV56cPK0I7P6UPguc3qGD8E3wfUC+fByDzX4JZ6OuhyrwcCCgbyjnBgI/FoWBA364cKONEH69p851Jy+zRaI9hWKKOQ/hqHqpWL266vgnALkvjcfZS3Frc6rRTvRIzetVufrJM3i9OAfnoLPZz5crraRQgUpgcPUd9fYhl59PIK35jRaENXunDUa8NE/J8kObcZE+910NxsUo7LzsGssr6UOPM6slKhnocnbqCrrNLhoF0jLXbSObuCXKh5HuGV8Y51UdsK6oUuct+ScfOZGBl+/6LhaGmlS0Ab58R7CO8UqhX4j91H8YW6xtDTQoAIXNU2j4Zq7lkpH0b5Vv7ZhFnbbc1OgTtboTcKwyRXgZFlBa6NNIb7GvRMyKdWW+sHXFAXGohZubp7DXsr6gQ/8eqcTuiiLKChRbY6MhG14OkGw4/LcuBAxEg6Fy7JX3tlMfto3LcfhFVvlmM1XuWACR9OJLr49YAkBYsMWl95qK5tSG0Wo/hAqjcPWPszrzK9Uo9AsDpsCHGnX57Ytcsi60y+jnV7iQqhoWtaT+UJW9FbxOPpKTsQw0k2GPM/1d+ulMz2IYPrN/Bsuk34OyAUID1zEUnSro0Q4camHfW2wnJvW77rLmfqO2b0M4+UuEgbgB/dyQtICsNndaO1x6S3pL8/typqoakwx/9xg02QVzLLRvfs4Su9eSAsKL/QfQCI9dmS8O0kvA1DqbUdxO6HfrfCVpGKoLajB4dZ/1nplNFFL+ap7vXOU9F4foXemT4f71T3S93NWb6gFU8jB8WxNaoWVBoeuP7iJNMqqBZPvV9SJ94lELlV/LZKlZ+pqQML/Gfe565AmXD34ekgE5ZGkwQxSoP8BksbDnL41GxEZtvWHcr+kSZK2FoTBwsXBye43qy1ZFYV+guSPqgsy5S215c2r4g+zfJ2vlC5+k2621Dwex7POA68LrtfbyeFJ8gQY7nZMPNp2gZQHmY/imA1Fb0jiCfMzYUiWumJeyOeiSUE5p/slwV0SryaYtT73fjx37F/iUAE5zl6yEo8v45aiB2XNgxdTU4bjHEFD+sj/6DGp27ukt6vLxN/QhmPvU7yYUA+u1WbQblof6VN7AwhVUqgqUx9Je0kSXPrI12K/2yC6eZnGuXeicqwIxCQWh9z9o24NzUkaiVC7VnSItVgXDWwviwAe4H1LxNU9y6j+Y0R8iGclRQVN8haBc1x7BWO6raGsLRrKblykBsIydnuz1Bvjk4eEaoH1rCzzIiuj1ZqG3bo/bLxjJw1h1KmnXkywo8alCusMIog71a3FQnST+idwJ9+tJU31rqMxinD1kUwG5ZYmFnpRZWHD57gsa5rzFptjbnkUxfBhHD3+7mO6qlgMidjzfv77MuFWRVyglDMD+eNvlX6vmPm93Qq4rDZTDssck6IYCaQ6TuqXJ2WEal0HDgaX/rlyhUL/4T7Ptk2/QoQqekUasvbjPhpn25R9AGTIcEwdoVsK2kC4ftvtkc2g1jE4PK2fLqe6sNfCEebZT18nx5FdgELbkSB+ss3aLfvWVVC0EJJmdlW+F1mxxPnkfvwcCfj4YKsfhEMoiPxbs0As2dtbaV9xcrhFlGZFoA/idudJqRPEuZvhtiJ2L0MQMuDWqT6kDr6wqnAghj2olacMb9rU5IlK9hfoCalMp7/adEJLpzJ7RdZd6o8cGq0D2v9lsT/2OJtq+kiMIG3gzIDrHSCK7v3XFpmA6DcMsgUHyYGSe1Mfe6fD+mPXyKWEi+hp3SJjDHa3Xk0bx5java0fZc/q/t9yxxjijIVGlRrduMj0GQpi3JHOL/JZoGWHrMSQFBmLIEypj+Dp1nImOja7j69VlK6q1dxELdx1sE5eIzTpk0/bRZ3oyqFtXYwyWUJsx5evdJSPIGbM8lgQsV8yO9U8LRot2BhWyfsU8NWRsHY5ihYb2K/Y9saE1iML4uqvIAK36eG9DuRaz2zIa6K3G5Xr/U8c0BxUxNNcWIra7TPyVmIXhLm85ghX9qKWNM2YQO/02tvIAI/9+8qANblayjg31j+FjME1NNGQg3jxA28QyfN39b0Fg8sD5MWmHP6MtvfVwx0JM88n1eCJiZ0No5BFUOB/EfgtiXp48ledg66cLjPmU9rjKPNyK4iUsRO7IY9X0/7L4M+d+8tBOy14Bfjn0ELi6HdF5+HVgWp3DViCn8iX4HCVrTX9S4/ZrgJVDJdI5axuGlsaH3VqCV0Rfes/p3MfcjUVOpBja+byTWMbM0ZONjrF3NAtzwZwLN+QDVEVS8Hso11mYsL6IvEbKsGYySBcX6qZ57p0MlPeC0GPPy0DkDca19W/fWFkrlPP60plNymq+c9HZ1Ghmg9YSGluckJLidqR6wuCSSkyaSwjJaJYnu4MIfXrLP4Q0UmKwvVJFSNqhtDSaus+U2+m8sl6CadTs4trw2iVh78/Wpghvido18f7A40MFo8E3OLN9XEgXA2FLMPrGiZM3JFTMutokburAgTAxs7CmbqilP4ArWvxEvG+TbmCatA5PhhGibms3OO910cjaToRUXriE8K7kHRM7Miui7qDcCM+wcgPOV+sYNNucAAbseGi+Mej1tmMLTUO4k8q2bRcadMaijASasX6Q8k6k1YGy89HTh1UkwCLdd6F4eYHsDFpMGjwJ2I1fJ/4lmTAUYOHP3n4p4ovOSoptgIul9sty7iqZnQlkQHeVWQSwMzyBbcxTqA6GDsdNk5GF+Wjaf3C3F+uOhRY+yD0wbb43d3rpEMPkThbTTsN8ricg0bDSIWnM2FKfsQ0QFbZuC2JrkeSEZuLd3RldLsUXBzrQl2ub49oztmjEQSu6GePyz9LAeQRJd6EUQ4/I/vu1SLyHcXZAch4zrzk2u+7OWehE+i/CGzRWL14/x+z3PPmguYOqS1rJdCWDIKlIXD9nZc/heFhQ4QiV2pvr0ElYHCDnAq/SgpPC7EFy4BGmz6cMJ2Az44cijzOFbYZ1+rkbxvLV4Q2QVDj5tgBNYrV7FYBs+B0kF3D/ijbp1JGowGDsXJC1KaUpu01OL9962042O3b4RIU6NsGa0irMip/IAlFYhEW72Aj6oNvqNKDf7VjT3GYvRRz51zPMaKymBLCDw2lSrz7tTkN8L3w7dyLzBpzNI894Id3B6lf+ummAp+w0y0Q/jQnNzUFJznXIoais7JwcC+jxkolAW6iCXwGYGbYLTgV1jKH1GdJf10yzMo/obPF2F4vtRITmq3PGRV1DEm9ELbu3ajhSP4vh9eUqxki/ORrJibn6MVBz1GtzOzFBZ8br2ZZLCxqq4bTVj/BXPngVZ6bxmxdn7rf15a4IcPRZ9hPEl/M3vIl6cSJLb3M45lADfDtBW70dXMFAcof2ipkngcOf2NY/dYuGUMMyOp/Xetvy4kFY2ye2nU0PEq0GhwxxCB/zrGzxprC7W93sVETAlPXyb9yirlo4elyaNIZMt+sqlUHGoFyK3xDPlkNAwrsQWgghNwMrtZ7Fm683n38X9HVwgGjJpeoODfIph+f0vDl+ncO2GywdSbJXQg5Tf5PTONvZb+8Kd2F7Lv8mljtqAKHoh/b7MyogGvA914hUL3jKFClnAaD9xXWCK83stRL2Hqg2PmY+aNwB3m/Y54QEYdq+Xu7nIWo8EkncKTB4GwLb7Cyep88E5WNnyaU1Y337xAEGN9403pqCp+abrFgMOLl1MAPoWXNEGsQIEqVJECkPgpdR1eU83LjPXjSthCe5mo2Vc35IgOOA94UEDfaXyRqQEE5CH+QRdXCc4oMKt3cTUHiPlPbHayKVH5d1lntDxMgJ5tSN1kwcFQMKYXJdYSZqatoYNar0tnSF2EuGPs2ium1h4Il/NKCPiZySDbYRwDITMu+RVMvr5CbmXHF93bz/d0n8Qg8A2qmrU="
|
||||
}"#;
|
||||
|
||||
let _ = serde_json::from_str::<Part>(input).unwrap();
|
||||
}
|
||||
}
|
||||
49
src/types/count_tokens.rs
Normal file
49
src/types/count_tokens.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::Content;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CountTokensRequest {
|
||||
pub contents: Content,
|
||||
}
|
||||
|
||||
impl CountTokensRequest {
|
||||
pub fn builder() -> CountTokensRequestBuilder {
|
||||
CountTokensRequestBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CountTokensRequestBuilder {
|
||||
contents: Content,
|
||||
}
|
||||
|
||||
impl CountTokensRequestBuilder {
|
||||
pub fn from_prompt(prompt: &str) -> Self {
|
||||
CountTokensRequestBuilder {
|
||||
contents: Content {
|
||||
parts: Some(vec![super::Part::from_text(prompt.to_string())]),
|
||||
..Default::default()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build(self) -> CountTokensRequest {
|
||||
CountTokensRequest {
|
||||
contents: self.contents,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum CountTokensResponse {
|
||||
#[serde(rename_all = "camelCase")]
|
||||
Ok {
|
||||
total_tokens: i32,
|
||||
total_billable_characters: u32,
|
||||
},
|
||||
Error {
|
||||
error: super::VertexApiError,
|
||||
},
|
||||
}
|
||||
67
src/types/error.rs
Normal file
67
src/types/error.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
use std::fmt::Formatter;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct VertexApiError {
|
||||
pub code: i32,
|
||||
pub message: String,
|
||||
pub status: String,
|
||||
pub details: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for VertexApiError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
writeln!(f, "Vertex API Error {} - {}", self.code, self.message)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for VertexApiError {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct GeminiApiError {
|
||||
pub error: VertexApiError,
|
||||
}
|
||||
|
||||
impl core::fmt::Display for GeminiApiError {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
|
||||
write!(f, "Gemini API Error {} - {}", self.error.code, self.error.message)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for GeminiApiError {}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Link {
|
||||
pub description: String,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "@type")]
|
||||
pub enum ErrorType {
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.ErrorInfo")]
|
||||
ErrorInfo { metadata: ErrorInfoMetadata },
|
||||
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.Help")]
|
||||
Help { links: Vec<Link> },
|
||||
|
||||
#[serde(rename = "type.googleapis.com/google.rpc.BadRequest")]
|
||||
BadRequest {
|
||||
#[serde(rename = "fieldViolations")]
|
||||
field_violations: Vec<FieldViolation>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ErrorInfoMetadata {
|
||||
pub service: String,
|
||||
pub consumer: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct FieldViolation {
|
||||
pub field: String,
|
||||
pub description: String,
|
||||
}
|
||||
622
src/types/generate_content.rs
Normal file
622
src/types/generate_content.rs
Normal file
@@ -0,0 +1,622 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::{Content, VertexApiError};
|
||||
use crate::error::Result;
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentRequest {
|
||||
pub contents: Vec<Content>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub generation_config: Option<GenerationConfig>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<Tools>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safety_settings: Option<Vec<SafetySetting>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system_instruction: Option<Content>,
|
||||
}
|
||||
|
||||
impl GenerateContentRequest {
|
||||
pub fn builder() -> GenerateContentRequestBuilder {
|
||||
GenerateContentRequestBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GenerateContentRequestBuilder {
|
||||
request: GenerateContentRequest,
|
||||
}
|
||||
|
||||
impl GenerateContentRequestBuilder {
|
||||
fn new() -> Self {
|
||||
GenerateContentRequestBuilder {
|
||||
request: GenerateContentRequest::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contents(mut self, contents: Vec<Content>) -> Self {
|
||||
self.request.contents = contents;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn generation_config(mut self, generation_config: GenerationConfig) -> Self {
|
||||
self.request.generation_config = Some(generation_config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn tools(mut self, tools: Vec<Tools>) -> Self {
|
||||
self.request.tools = Some(tools);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn safety_settings(mut self, safety_settings: Vec<SafetySetting>) -> Self {
|
||||
self.request.safety_settings = Some(safety_settings);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn system_instruction(mut self, system_instruction: Content) -> Self {
|
||||
self.request.system_instruction = Some(system_instruction);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> GenerateContentRequest {
|
||||
self.request
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
pub struct Tools {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||
|
||||
#[serde(rename = "googleSearchRetrieval")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub google_search_retrieval: Option<GoogleSearchRetrieval>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub google_search: Option<GoogleSearch>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
pub struct GoogleSearch {}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DynamicRetrievalConfig {
|
||||
pub mode: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dynamic_threshold: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for DynamicRetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: "MODE_DYNAMIC".to_string(),
|
||||
dynamic_threshold: Some(0.7),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GoogleSearchRetrieval {
|
||||
pub dynamic_retrieval_config: DynamicRetrievalConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerationConfig {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_output_tokens: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_k: Option<i32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequences: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub candidate_count: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_mime_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_schema: Option<Value>,
|
||||
}
|
||||
|
||||
impl GenerationConfig {
|
||||
pub fn builder() -> GenerationConfigBuilder {
|
||||
GenerationConfigBuilder::new()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct GenerationConfigBuilder {
|
||||
generation_config: GenerationConfig,
|
||||
}
|
||||
|
||||
impl GenerationConfigBuilder {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
generation_config: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_output_tokens<T: Into<i32>>(mut self, max_output_tokens: T) -> Self {
|
||||
self.generation_config.max_output_tokens = Some(max_output_tokens.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn temperature<T: Into<f32>>(mut self, temperature: T) -> Self {
|
||||
self.generation_config.temperature = Some(temperature.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_p<T: Into<f32>>(mut self, top_p: T) -> Self {
|
||||
self.generation_config.top_p = Some(top_p.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn top_k<T: Into<i32>>(mut self, top_k: T) -> Self {
|
||||
self.generation_config.top_k = Some(top_k.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stop_sequences<T: Into<Vec<String>>>(mut self, stop_sequences: T) -> Self {
|
||||
self.generation_config.stop_sequences = Some(stop_sequences.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn candidate_count<T: Into<u32>>(mut self, candidate_count: T) -> Self {
|
||||
self.generation_config.candidate_count = Some(candidate_count.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn response_mime_type<T: Into<String>>(mut self, response_mime_type: T) -> Self {
|
||||
self.generation_config.response_mime_type = Some(response_mime_type.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn response_schema<T: Into<Value>>(mut self, response_schema: T) -> Self {
|
||||
self.generation_config.response_schema = Some(response_schema.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> GenerationConfig {
|
||||
self.generation_config
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetySetting {
|
||||
pub category: HarmCategory,
|
||||
pub threshold: HarmBlockThreshold,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub method: Option<HarmBlockMethod>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HarmCategory {
|
||||
#[serde(rename = "HARM_CATEGORY_UNSPECIFIED")]
|
||||
Unspecified,
|
||||
#[serde(rename = "HARM_CATEGORY_HATE_SPEECH")]
|
||||
HateSpeech,
|
||||
#[serde(rename = "HARM_CATEGORY_DANGEROUS_CONTENT")]
|
||||
DangerousContent,
|
||||
#[serde(rename = "HARM_CATEGORY_HARASSMENT")]
|
||||
Harassment,
|
||||
#[serde(rename = "HARM_CATEGORY_SEXUALLY_EXPLICIT")]
|
||||
SexuallyExplicit,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HarmBlockThreshold {
|
||||
#[serde(rename = "HARM_BLOCK_THRESHOLD_UNSPECIFIED")]
|
||||
Unspecified,
|
||||
#[serde(rename = "BLOCK_LOW_AND_ABOVE")]
|
||||
BlockLowAndAbove,
|
||||
#[serde(rename = "BLOCK_MEDIUM_AND_ABOVE")]
|
||||
BlockMediumAndAbove,
|
||||
#[serde(rename = "BLOCK_ONLY_HIGH")]
|
||||
BlockOnlyHigh,
|
||||
#[serde(rename = "BLOCK_NONE")]
|
||||
BlockNone,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum HarmBlockMethod {
|
||||
#[serde(rename = "HARM_BLOCK_METHOD_UNSPECIFIED")]
|
||||
Unspecified, // HARM_BLOCK_METHOD_UNSPECIFIED
|
||||
#[serde(rename = "SEVERITY")]
|
||||
Severity, // SEVERITY
|
||||
#[serde(rename = "PROBABILITY")]
|
||||
Probability, // PROBABILITY
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Candidate {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<Content>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub citation_metadata: Option<CitationMetadata>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safety_ratings: Option<Vec<SafetyRating>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub finish_reason: Option<String>,
|
||||
pub index: u32,
|
||||
}
|
||||
|
||||
impl Candidate {
|
||||
pub fn get_text(&self) -> Option<String> {
|
||||
match &self.content {
|
||||
Some(content) => content.get_text(),
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Citation {
|
||||
pub start_index: Option<i32>,
|
||||
pub end_index: Option<i32>,
|
||||
pub uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct CitationMetadata {
|
||||
#[serde(alias = "citationSources")]
|
||||
pub citations: Vec<Citation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SafetyRating {
|
||||
pub category: String,
|
||||
pub probability: String,
|
||||
pub probability_score: Option<f32>,
|
||||
pub severity: String,
|
||||
pub severity_score: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct UsageMetadata {
|
||||
pub candidates_token_count: Option<u32>,
|
||||
pub prompt_token_count: Option<u32>,
|
||||
pub total_token_count: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionDeclaration {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
// TODO: add behaviour field - https://ai.google.dev/api/caching#Behavior
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parameters: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parameters_json_schema: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_json_schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// See https://ai.google.dev/api/caching#FunctionResponse
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
pub name: String,
|
||||
pub response: Value,
|
||||
// TODO: Add missing properties from docs.
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FunctionParametersProperty {
|
||||
pub r#type: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum GenerateContentResponse {
|
||||
Ok(GenerateContentResponseResult),
|
||||
Error(GenerateContentResponseError),
|
||||
}
|
||||
|
||||
impl From<GenerateContentResponse> for Result<GenerateContentResponseResult> {
|
||||
fn from(val: GenerateContentResponse) -> Self {
|
||||
match val {
|
||||
GenerateContentResponse::Ok(result) => Ok(result),
|
||||
GenerateContentResponse::Error(error) => Err(error.error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GenerateContentResponseResult {
|
||||
pub candidates: Vec<Candidate>,
|
||||
pub usage_metadata: Option<UsageMetadata>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct GenerateContentResponseError {
|
||||
pub error: VertexApiError,
|
||||
}
|
||||
|
||||
impl GenerateContentResponse {
|
||||
pub fn into_result(self) -> Result<GenerateContentResponseResult> {
|
||||
match self {
|
||||
GenerateContentResponse::Ok(result) => Ok(result),
|
||||
GenerateContentResponse::Error(error) => Err(error.error.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::types::{Candidate, UsageMetadata};
|
||||
|
||||
use super::{GenerateContentResponse, GenerateContentResponseResult};
|
||||
|
||||
#[test]
|
||||
pub fn parses_empty_metadata_response() {
|
||||
let input = r#"{"candidates": [{"content": {"role": "model","parts": [{"text": "-"}]}}],"usageMetadata": {}}"#;
|
||||
serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn parses_usage_metadata() {
|
||||
let input = r#"
|
||||
{
|
||||
"promptTokenCount": 11,
|
||||
"candidatesTokenCount": 202,
|
||||
"totalTokenCount": 1041,
|
||||
"promptTokensDetails": [
|
||||
{
|
||||
"modality": "TEXT",
|
||||
"tokenCount": 11
|
||||
}
|
||||
],
|
||||
"thoughtsTokenCount": 828
|
||||
}"#;
|
||||
let _ = serde_json::from_str::<UsageMetadata>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn parses_candidate() {
|
||||
let input = r#"
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "**What do you mean? An African or a European swallow?**\n\nIf you are looking for the actual physics rather than the *Monty Python and the Holy Grail* reference, here is the breakdown:\n\n**1. The European Swallow**\nBased on an analysis published by Jonathan Corum (using data on the Strouhal number of cruising flight), the estimated airspeed velocity of an unladen European Swallow is roughly **11 meters per second**, or **24 miles per hour**.\n\n**2. The African Swallow**\nData on the African swallow is scarcer, mostly because—as the guard in the movie points out—African swallows are non-migratory. However, since they are similar in size to their European counterparts, their cruising speed would likely be comparable.\n\nBut of course, the real question is: *Could it carry a coconut?* (A five-ounce bird could not carry a one-pound coconut. It is a simple question of weight ratios.)",
|
||||
"thoughtSignature": "EqcZCqQZAdHtim/53UNFI7YRLcEDch1I/mLfWNT6lVjgXb7RsNnYn8JLU8Y6UhAi4nkLJ/nK2l44Y+JJZimQ2rLpRfdlBAPkhVsuZYenAY7MRXG9GQrSzz1elR+L6FAb0dyb9snnGz5NdlKCyS9VIWKIhghmHA60oEnEUexaJD2mq3ZV4kJ8R/d+UJEEdOD9CdlnB1WnOvHaiT15mLSj8JxclI+1mml86b5hjA0F+MLVWesa4gjo6/OfNo1k+tA+JioUAu8hgZ5DJttNxs/BvrLMyY/+d6qm40Ht45BuNlKUjFTkrUOIx5oAld3PnNj804Ou3F/sv8i5UMh9TcWyuiOjP3lZU5t1GEKQJ/YY9CxN/Zl71Kzk51Z+92IV2tKLqZVsEkrIr5o33QmNRTIeX0zMSQRdhlTBPuwSa+l91SV56cPK0I7P6UPguc3qGD8E3wfUC+fByDzX4JZ6OuhyrwcCCgbyjnBgI/FoWBA364cKONEH69p851Jy+zRaI9hWKKOQ/hqHqpWL266vgnALkvjcfZS3Frc6rRTvRIzetVufrJM3i9OAfnoLPZz5crraRQgUpgcPUd9fYhl59PIK35jRaENXunDUa8NE/J8kObcZE+910NxsUo7LzsGssr6UOPM6slKhnocnbqCrrNLhoF0jLXbSObuCXKh5HuGV8Y51UdsK6oUuct+ScfOZGBl+/6LhaGmlS0Ab58R7CO8UqhX4j91H8YW6xtDTQoAIXNU2j4Zq7lkpH0b5Vv7ZhFnbbc1OgTtboTcKwyRXgZFlBa6NNIb7GvRMyKdWW+sHXFAXGohZubp7DXsr6gQ/8eqcTuiiLKChRbY6MhG14OkGw4/LcuBAxEg6Fy7JX3tlMfto3LcfhFVvlmM1XuWACR9OJLr49YAkBYsMWl95qK5tSG0Wo/hAqjcPWPszrzK9Uo9AsDpsCHGnX57Ytcsi60y+jnV7iQqhoWtaT+UJW9FbxOPpKTsQw0k2GPM/1d+ulMz2IYPrN/Bsuk34OyAUID1zEUnSro0Q4camHfW2wnJvW77rLmfqO2b0M4+UuEgbgB/dyQtICsNndaO1x6S3pL8/typqoakwx/9xg02QVzLLRvfs4Su9eSAsKL/QfQCI9dmS8O0kvA1DqbUdxO6HfrfCVpGKoLajB4dZ/1nplNFFL+ap7vXOU9F4foXemT4f71T3S93NWb6gFU8jB8WxNaoWVBoeuP7iJNMqqBZPvV9SJ94lELlV/LZKlZ+pqQML/Gfe565AmXD34ekgE5ZGkwQxSoP8BksbDnL41GxEZtvWHcr+kSZK2FoTBwsXBye43qy1ZFYV+guSPqgsy5S215c2r4g+zfJ2vlC5+k2621Dwex7POA68LrtfbyeFJ8gQY7nZMPNp2gZQHmY/imA1Fb0jiCfMzYUiWumJeyOeiSUE5p/slwV0SryaYtT73fjx37F/iUAE5zl6yEo8v45aiB2XNgxdTU4bjHEFD+sj/6DGp27ukt6vLxN/QhmPvU7yYUA+u1WbQblof6VN7AwhVUqgqUx9Je0kSXPrI12K/2yC6eZnGuXeicqwIxCQWh9z9o24NzUkaiVC7VnSItVgXDWwviwAe4H1LxNU9y6j+Y0R8iGclRQVN8haBc1x7BWO6raGsLRrKblykBsIydnuz1Bvjk4eEaoH1rCzzIiuj1ZqG3bo/bLxjJw1h1KmnXkywo8alCusMIog71a3FQnST+idwJ9+tJU31rqMxinD1kUwG5ZYmFnpRZWHD57gsa5rzFptjbnkUxfBhHD3+7mO6qlgMidjzfv77MuFWRVyglDMD+eNvlX6vmPm93Qq4rDZTDssck6IYCaQ6TuqXJ2WEal0HDgaX/rlyhUL/4T7Ptk2/QoQqekUasvbjPhpn25R9AGTIcEwdoVsK2kC4ftvtkc2g1jE4PK2fLqe6sNfCEebZT18nx5FdgELbkSB+ss3aLfvWVVC0EJJmdlW+F1mxxPnkfvwcCfj4YKsfhEMoiPxbs0As2dtbaV9xcrhFlGZFoA/idudJqRPEuZvhtiJ2L0MQMuDWqT6kDr6wqnAghj2olacMb9rU5IlK9hfoCalMp7/adEJLpzJ7RdZd6o8cGq0D2v9lsT/2OJtq+kiMIG3gzIDrHSCK7v3XFpmA6DcMsgUHyYGSe1Mfe6fD+mPXyKWEi+hp3SJjDHa3Xk0bx5java0fZc/q/t9yxxjijIVGlRrduMj0GQpi3JHOL/JZoGWHrMSQFBmLIEypj+Dp1nImOja7j69VlK6q1dxELdx1sE5eIzTpk0/bRZ3oyqFtXYwyWUJsx5evdJSPIGbM8lgQsV8yO9U8LRot2BhWyfsU8NWRsHY5ihYb2K/Y9saE1iML4uqvIAK36eG9DuRaz2zIa6K3G5Xr/U8c0BxUxNNcWIra7TPyVmIXhLm85ghX9qKWNM2YQO/02tvIAI/9+8qANblayjg31j+FjME1NNGQg3jxA28QyfN39b0Fg8sD5MWmHP6MtvfVwx0JM88n1eCJiZ0No5BFUOB/EfgtiXp48ledg66cLjPmU9rjKPNyK4iUsRO7IY9X0/7L4M+d+8tBOy14Bfjn0ELi6HdF5+HVgWp3DViCn8iX4HCVrTX9S4/ZrgJVDJdI5axuGlsaH3VqCV0Rfes/p3MfcjUVOpBja+byTWMbM0ZONjrF3NAtzwZwLN+QDVEVS8Hso11mYsL6IvEbKsGYySBcX6qZ57p0MlPeC0GPPy0DkDca19W/fWFkrlPP60plNymq+c9HZ1Ghmg9YSGluckJLidqR6wuCSSkyaSwjJaJYnu4MIfXrLP4Q0UmKwvVJFSNqhtDSaus+U2+m8sl6CadTs4trw2iVh78/Wpghvido18f7A40MFo8E3OLN9XEgXA2FLMPrGiZM3JFTMutokburAgTAxs7CmbqilP4ArWvxEvG+TbmCatA5PhhGibms3OO910cjaToRUXriE8K7kHRM7Miui7qDcCM+wcgPOV+sYNNucAAbseGi+Mej1tmMLTUO4k8q2bRcadMaijASasX6Q8k6k1YGy89HTh1UkwCLdd6F4eYHsDFpMGjwJ2I1fJ/4lmTAUYOHP3n4p4ovOSoptgIul9sty7iqZnQlkQHeVWQSwMzyBbcxTqA6GDsdNk5GF+Wjaf3C3F+uOhRY+yD0wbb43d3rpEMPkThbTTsN8ricg0bDSIWnM2FKfsQ0QFbZuC2JrkeSEZuLd3RldLsUXBzrQl2ub49oztmjEQSu6GePyz9LAeQRJd6EUQ4/I/vu1SLyHcXZAch4zrzk2u+7OWehE+i/CGzRWL14/x+z3PPmguYOqS1rJdCWDIKlIXD9nZc/heFhQ4QiV2pvr0ElYHCDnAq/SgpPC7EFy4BGmz6cMJ2Az44cijzOFbYZ1+rkbxvLV4Q2QVDj5tgBNYrV7FYBs+B0kF3D/ijbp1JGowGDsXJC1KaUpu01OL9962042O3b4RIU6NsGa0irMip/IAlFYhEW72Aj6oNvqNKDf7VjT3GYvRRz51zPMaKymBLCDw2lSrz7tTkN8L3w7dyLzBpzNI894Id3B6lf+ummAp+w0y0Q/jQnNzUFJznXIoais7JwcC+jxkolAW6iCXwGYGbYLTgV1jKH1GdJf10yzMo/obPF2F4vtRITmq3PGRV1DEm9ELbu3ajhSP4vh9eUqxki/ORrJibn6MVBz1GtzOzFBZ8br2ZZLCxqq4bTVj/BXPngVZ6bxmxdn7rf15a4IcPRZ9hPEl/M3vIl6cSJLb3M45lADfDtBW70dXMFAcof2ipkngcOf2NY/dYuGUMMyOp/Xetvy4kFY2ye2nU0PEq0GhwxxCB/zrGzxprC7W93sVETAlPXyb9yirlo4elyaNIZMt+sqlUHGoFyK3xDPlkNAwrsQWgghNwMrtZ7Fm683n38X9HVwgGjJpeoODfIph+f0vDl+ncO2GywdSbJXQg5Tf5PTONvZb+8Kd2F7Lv8mljtqAKHoh/b7MyogGvA914hUL3jKFClnAaD9xXWCK83stRL2Hqg2PmY+aNwB3m/Y54QEYdq+Xu7nIWo8EkncKTB4GwLb7Cyep88E5WNnyaU1Y337xAEGN9403pqCp+abrFgMOLl1MAPoWXNEGsQIEqVJECkPgpdR1eU83LjPXjSthCe5mo2Vc35IgOOA94UEDfaXyRqQEE5CH+QRdXCc4oMKt3cTUHiPlPbHayKVH5d1lntDxMgJ5tSN1kwcFQMKYXJdYSZqatoYNar0tnSF2EuGPs2ium1h4Il/NKCPiZySDbYRwDITMu+RVMvr5CbmXHF93bz/d0n8Qg8A2qmrU="
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"index": 0
|
||||
}"#;
|
||||
let _ = serde_json::from_str::<Candidate>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn parses_google_ai_response() {
|
||||
let input = r#"
|
||||
{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": "**What do you mean? An African or a European swallow?**\n\nIf you are looking for the actual physics rather than the *Monty Python and the Holy Grail* reference, here is the breakdown:\n\n**1. The European Swallow**\nBased on an analysis published by Jonathan Corum (using data on the Strouhal number of cruising flight), the estimated airspeed velocity of an unladen European Swallow is roughly **11 meters per second**, or **24 miles per hour**.\n\n**2. The African Swallow**\nData on the African swallow is scarcer, mostly because—as the guard in the movie points out—African swallows are non-migratory. However, since they are similar in size to their European counterparts, their cruising speed would likely be comparable.\n\nBut of course, the real question is: *Could it carry a coconut?* (A five-ounce bird could not carry a one-pound coconut. It is a simple question of weight ratios.)",
|
||||
"thoughtSignature": "EqcZCqQZAdHtim/53UNFI7YRLcEDch1I/mLfWNT6lVjgXb7RsNnYn8JLU8Y6UhAi4nkLJ/nK2l44Y+JJZimQ2rLpRfdlBAPkhVsuZYenAY7MRXG9GQrSzz1elR+L6FAb0dyb9snnGz5NdlKCyS9VIWKIhghmHA60oEnEUexaJD2mq3ZV4kJ8R/d+UJEEdOD9CdlnB1WnOvHaiT15mLSj8JxclI+1mml86b5hjA0F+MLVWesa4gjo6/OfNo1k+tA+JioUAu8hgZ5DJttNxs/BvrLMyY/+d6qm40Ht45BuNlKUjFTkrUOIx5oAld3PnNj804Ou3F/sv8i5UMh9TcWyuiOjP3lZU5t1GEKQJ/YY9CxN/Zl71Kzk51Z+92IV2tKLqZVsEkrIr5o33QmNRTIeX0zMSQRdhlTBPuwSa+l91SV56cPK0I7P6UPguc3qGD8E3wfUC+fByDzX4JZ6OuhyrwcCCgbyjnBgI/FoWBA364cKONEH69p851Jy+zRaI9hWKKOQ/hqHqpWL266vgnALkvjcfZS3Frc6rRTvRIzetVufrJM3i9OAfnoLPZz5crraRQgUpgcPUd9fYhl59PIK35jRaENXunDUa8NE/J8kObcZE+910NxsUo7LzsGssr6UOPM6slKhnocnbqCrrNLhoF0jLXbSObuCXKh5HuGV8Y51UdsK6oUuct+ScfOZGBl+/6LhaGmlS0Ab58R7CO8UqhX4j91H8YW6xtDTQoAIXNU2j4Zq7lkpH0b5Vv7ZhFnbbc1OgTtboTcKwyRXgZFlBa6NNIb7GvRMyKdWW+sHXFAXGohZubp7DXsr6gQ/8eqcTuiiLKChRbY6MhG14OkGw4/LcuBAxEg6Fy7JX3tlMfto3LcfhFVvlmM1XuWACR9OJLr49YAkBYsMWl95qK5tSG0Wo/hAqjcPWPszrzK9Uo9AsDpsCHGnX57Ytcsi60y+jnV7iQqhoWtaT+UJW9FbxOPpKTsQw0k2GPM/1d+ulMz2IYPrN/Bsuk34OyAUID1zEUnSro0Q4camHfW2wnJvW77rLmfqO2b0M4+UuEgbgB/dyQtICsNndaO1x6S3pL8/typqoakwx/9xg02QVzLLRvfs4Su9eSAsKL/QfQCI9dmS8O0kvA1DqbUdxO6HfrfCVpGKoLajB4dZ/1nplNFFL+ap7vXOU9F4foXemT4f71T3S93NWb6gFU8jB8WxNaoWVBoeuP7iJNMqqBZPvV9SJ94lELlV/LZKlZ+pqQML/Gfe565AmXD34ekgE5ZGkwQxSoP8BksbDnL41GxEZtvWHcr+kSZK2FoTBwsXBye43qy1ZFYV+guSPqgsy5S215c2r4g+zfJ2vlC5+k2621Dwex7POA68LrtfbyeFJ8gQY7nZMPNp2gZQHmY/imA1Fb0jiCfMzYUiWumJeyOeiSUE5p/slwV0SryaYtT73fjx37F/iUAE5zl6yEo8v45aiB2XNgxdTU4bjHEFD+sj/6DGp27ukt6vLxN/QhmPvU7yYUA+u1WbQblof6VN7AwhVUqgqUx9Je0kSXPrI12K/2yC6eZnGuXeicqwIxCQWh9z9o24NzUkaiVC7VnSItVgXDWwviwAe4H1LxNU9y6j+Y0R8iGclRQVN8haBc1x7BWO6raGsLRrKblykBsIydnuz1Bvjk4eEaoH1rCzzIiuj1ZqG3bo/bLxjJw1h1KmnXkywo8alCusMIog71a3FQnST+idwJ9+tJU31rqMxinD1kUwG5ZYmFnpRZWHD57gsa5rzFptjbnkUxfBhHD3+7mO6qlgMidjzfv77MuFWRVyglDMD+eNvlX6vmPm93Qq4rDZTDssck6IYCaQ6TuqXJ2WEal0HDgaX/rlyhUL/4T7Ptk2/QoQqekUasvbjPhpn25R9AGTIcEwdoVsK2kC4ftvtkc2g1jE4PK2fLqe6sNfCEebZT18nx5FdgELbkSB+ss3aLfvWVVC0EJJmdlW+F1mxxPnkfvwcCfj4YKsfhEMoiPxbs0As2dtbaV9xcrhFlGZFoA/idudJqRPEuZvhtiJ2L0MQMuDWqT6kDr6wqnAghj2olacMb9rU5IlK9hfoCalMp7/adEJLpzJ7RdZd6o8cGq0D2v9lsT/2OJtq+kiMIG3gzIDrHSCK7v3XFpmA6DcMsgUHyYGSe1Mfe6fD+mPXyKWEi+hp3SJjDHa3Xk0bx5java0fZc/q/t9yxxjijIVGlRrduMj0GQpi3JHOL/JZoGWHrMSQFBmLIEypj+Dp1nImOja7j69VlK6q1dxELdx1sE5eIzTpk0/bRZ3oyqFtXYwyWUJsx5evdJSPIGbM8lgQsV8yO9U8LRot2BhWyfsU8NWRsHY5ihYb2K/Y9saE1iML4uqvIAK36eG9DuRaz2zIa6K3G5Xr/U8c0BxUxNNcWIra7TPyVmIXhLm85ghX9qKWNM2YQO/02tvIAI/9+8qANblayjg31j+FjME1NNGQg3jxA28QyfN39b0Fg8sD5MWmHP6MtvfVwx0JM88n1eCJiZ0No5BFUOB/EfgtiXp48ledg66cLjPmU9rjKPNyK4iUsRO7IY9X0/7L4M+d+8tBOy14Bfjn0ELi6HdF5+HVgWp3DViCn8iX4HCVrTX9S4/ZrgJVDJdI5axuGlsaH3VqCV0Rfes/p3MfcjUVOpBja+byTWMbM0ZONjrF3NAtzwZwLN+QDVEVS8Hso11mYsL6IvEbKsGYySBcX6qZ57p0MlPeC0GPPy0DkDca19W/fWFkrlPP60plNymq+c9HZ1Ghmg9YSGluckJLidqR6wuCSSkyaSwjJaJYnu4MIfXrLP4Q0UmKwvVJFSNqhtDSaus+U2+m8sl6CadTs4trw2iVh78/Wpghvido18f7A40MFo8E3OLN9XEgXA2FLMPrGiZM3JFTMutokburAgTAxs7CmbqilP4ArWvxEvG+TbmCatA5PhhGibms3OO910cjaToRUXriE8K7kHRM7Miui7qDcCM+wcgPOV+sYNNucAAbseGi+Mej1tmMLTUO4k8q2bRcadMaijASasX6Q8k6k1YGy89HTh1UkwCLdd6F4eYHsDFpMGjwJ2I1fJ/4lmTAUYOHP3n4p4ovOSoptgIul9sty7iqZnQlkQHeVWQSwMzyBbcxTqA6GDsdNk5GF+Wjaf3C3F+uOhRY+yD0wbb43d3rpEMPkThbTTsN8ricg0bDSIWnM2FKfsQ0QFbZuC2JrkeSEZuLd3RldLsUXBzrQl2ub49oztmjEQSu6GePyz9LAeQRJd6EUQ4/I/vu1SLyHcXZAch4zrzk2u+7OWehE+i/CGzRWL14/x+z3PPmguYOqS1rJdCWDIKlIXD9nZc/heFhQ4QiV2pvr0ElYHCDnAq/SgpPC7EFy4BGmz6cMJ2Az44cijzOFbYZ1+rkbxvLV4Q2QVDj5tgBNYrV7FYBs+B0kF3D/ijbp1JGowGDsXJC1KaUpu01OL9962042O3b4RIU6NsGa0irMip/IAlFYhEW72Aj6oNvqNKDf7VjT3GYvRRz51zPMaKymBLCDw2lSrz7tTkN8L3w7dyLzBpzNI894Id3B6lf+ummAp+w0y0Q/jQnNzUFJznXIoais7JwcC+jxkolAW6iCXwGYGbYLTgV1jKH1GdJf10yzMo/obPF2F4vtRITmq3PGRV1DEm9ELbu3ajhSP4vh9eUqxki/ORrJibn6MVBz1GtzOzFBZ8br2ZZLCxqq4bTVj/BXPngVZ6bxmxdn7rf15a4IcPRZ9hPEl/M3vIl6cSJLb3M45lADfDtBW70dXMFAcof2ipkngcOf2NY/dYuGUMMyOp/Xetvy4kFY2ye2nU0PEq0GhwxxCB/zrGzxprC7W93sVETAlPXyb9yirlo4elyaNIZMt+sqlUHGoFyK3xDPlkNAwrsQWgghNwMrtZ7Fm683n38X9HVwgGjJpeoODfIph+f0vDl+ncO2GywdSbJXQg5Tf5PTONvZb+8Kd2F7Lv8mljtqAKHoh/b7MyogGvA914hUL3jKFClnAaD9xXWCK83stRL2Hqg2PmY+aNwB3m/Y54QEYdq+Xu7nIWo8EkncKTB4GwLb7Cyep88E5WNnyaU1Y337xAEGN9403pqCp+abrFgMOLl1MAPoWXNEGsQIEqVJECkPgpdR1eU83LjPXjSthCe5mo2Vc35IgOOA94UEDfaXyRqQEE5CH+QRdXCc4oMKt3cTUHiPlPbHayKVH5d1lntDxMgJ5tSN1kwcFQMKYXJdYSZqatoYNar0tnSF2EuGPs2ium1h4Il/NKCPiZySDbYRwDITMu+RVMvr5CbmXHF93bz/d0n8Qg8A2qmrU="
|
||||
}
|
||||
],
|
||||
"role": "model"
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 11,
|
||||
"candidatesTokenCount": 202,
|
||||
"totalTokenCount": 1041,
|
||||
"promptTokensDetails": [
|
||||
{
|
||||
"modality": "TEXT",
|
||||
"tokenCount": 11
|
||||
}
|
||||
],
|
||||
"thoughtsTokenCount": 828
|
||||
},
|
||||
"modelVersion": "gemini-3-pro-preview",
|
||||
"responseId": "2uUdaYPkG73WvdIP2aPs2Ak"
|
||||
}
|
||||
"#;
|
||||
let _ = serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn parses_max_tokens_response() {
|
||||
let input = r#"{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{
|
||||
"text": "Service workers are powerful and absolutely worth learning. They let you deliver an entirely new level of experience to your users. Your site can load instantly . It can work offline . It can be installed as a platform-specific app and feel every bit as polished—but with the reach and freedom of the web."
|
||||
}
|
||||
]
|
||||
},
|
||||
"finishReason": "MAX_TOKENS",
|
||||
"safetyRatings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.03882902,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.05781161
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.07626997,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.06705628
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.05749328,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.027532939
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.12929276,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.17838266
|
||||
}
|
||||
],
|
||||
"citationMetadata": {
|
||||
"citations": [
|
||||
{
|
||||
"endIndex": 151,
|
||||
"uri": "https://web.dev/service-worker-mindset/"
|
||||
},
|
||||
{
|
||||
"startIndex": 93,
|
||||
"endIndex": 297,
|
||||
"uri": "https://web.dev/service-worker-mindset/"
|
||||
},
|
||||
{
|
||||
"endIndex": 297
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 12069,
|
||||
"candidatesTokenCount": 61,
|
||||
"totalTokenCount": 12130
|
||||
}
|
||||
}"#;
|
||||
serde_json::from_str::<GenerateContentResponseResult>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_candidates_without_content() {
|
||||
let input = r#"{
|
||||
"candidates": [
|
||||
{
|
||||
"finishReason": "RECITATION",
|
||||
"safetyRatings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.08021325,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.0721122
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.19360436,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.1066906
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.07751766,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.040769264
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"probabilityScore": 0.030792166,
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE",
|
||||
"severityScore": 0.04138472
|
||||
}
|
||||
],
|
||||
"citationMetadata": {
|
||||
"citations": [
|
||||
{
|
||||
"startIndex": 1108,
|
||||
"endIndex": 1250,
|
||||
"uri": "https://chrome.google.com/webstore/detail/autocontrol-shortcut-mana/lkaihdpfpifdlgoapbfocpmekbokmcfd?hl=zh-TW"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 577,
|
||||
"totalTokenCount": 577
|
||||
}
|
||||
}"#;
|
||||
serde_json::from_str::<GenerateContentResponse>(input).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_safety_rating_without_scores() {
|
||||
let input = r#"{
|
||||
"candidates": [
|
||||
{
|
||||
"content": {
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{
|
||||
"text": "Return text"
|
||||
}
|
||||
]
|
||||
},
|
||||
"finishReason": "STOP",
|
||||
"safetyRatings": [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"probability": "NEGLIGIBLE",
|
||||
"severity": "HARM_SEVERITY_NEGLIGIBLE"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"usageMetadata": {
|
||||
"promptTokenCount": 5492,
|
||||
"candidatesTokenCount": 1256,
|
||||
"totalTokenCount": 6748
|
||||
}
|
||||
}"#;
|
||||
serde_json::from_str::<GenerateContentResponse>(input).unwrap();
|
||||
}
|
||||
}
|
||||
13
src/types/mod.rs
Normal file
13
src/types/mod.rs
Normal file
@@ -0,0 +1,13 @@
|
||||
mod common;
|
||||
mod count_tokens;
|
||||
mod error;
|
||||
mod generate_content;
|
||||
mod predict_image;
|
||||
mod text_embeddings;
|
||||
|
||||
pub use common::*;
|
||||
pub use count_tokens::*;
|
||||
pub use error::*;
|
||||
pub use generate_content::*;
|
||||
pub use predict_image::*;
|
||||
pub use text_embeddings::*;
|
||||
187
src/types/predict_image.rs
Normal file
187
src/types/predict_image.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::base64::Base64;
|
||||
use serde_with::serde_as;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequest {
|
||||
pub instances: Vec<PredictImageRequestPrompt>,
|
||||
pub parameters: PredictImageRequestParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageRequestPrompt {
|
||||
/// The text prompt for the image.
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||
/// - `imagegeneration@006`: up to 128 tokens.
|
||||
/// - `imagegeneration@005`: up to 128 tokens.
|
||||
/// - `imagegeneration@002`: up to 64 tokens.
|
||||
pub prompt: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParameters {
|
||||
/// The number of images to generate. The default value is 4.
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: 1 to 4.
|
||||
/// - `imagen-3.0-fast-generate-001`: 1 to 4.
|
||||
/// - `imagegeneration@006`: 1 to 4.
|
||||
/// - `imagegeneration@005`: 1 to 4.
|
||||
/// - `imagegeneration@002`: 1 to 8.
|
||||
pub sample_count: i32,
|
||||
|
||||
/// The random seed for image generation. This is not available when addWatermark is set to
|
||||
/// true.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<u32>,
|
||||
|
||||
/// Optional. An optional parameter to use an LLM-based prompt rewriting feature to deliver
|
||||
/// higher quality images that better reflect the original prompt's intent. Disabling this
|
||||
/// feature may impact image quality and prompt adherence
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enhance_prompt: Option<bool>,
|
||||
|
||||
/// A description of what to discourage in the generated images.
|
||||
/// The following models support this parameter:
|
||||
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||
/// - `imagegeneration@006`: up to 128 tokens.
|
||||
/// - `imagegeneration@005`: up to 128 tokens.
|
||||
/// - `imagegeneration@002`: up to 64 tokens.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub negative_prompt: Option<String>,
|
||||
|
||||
/// The aspect ratio for the image. The default value is "1:1".
|
||||
/// The following models support different values for this parameter:
|
||||
/// - `imagen-3.0-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagen-3.0-fast-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagegeneration@006`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||
/// - `imagegeneration@005`: "1:1" or "9:16".
|
||||
/// - `imagegeneration@002`: "1:1".
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub aspect_ratio: Option<String>,
|
||||
|
||||
/// Describes the output image format in an `PredictImageRequestParametersOutputOptions
|
||||
/// object.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_options: Option<PredictImageRequestParametersOutputOptions>,
|
||||
|
||||
/// Describes the style for the generated images. The following values are supported:
|
||||
/// - "photograph"
|
||||
/// - "digital_art"
|
||||
/// - "landscape"
|
||||
/// - "sketch"
|
||||
/// - "watercolor"
|
||||
/// - "cyberpunk"
|
||||
/// - "pop_art"
|
||||
///
|
||||
/// Pre-defined styles is only supported for model imagegeneration@002
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sample_image_style: Option<String>,
|
||||
|
||||
/// Allow generation of people by the model. The following values are supported:
|
||||
/// - `"dont_allow"`: Disallow the inclusion of people or faces in images.
|
||||
/// - `"allow_adult"`: Allow generation of adults only.
|
||||
/// - `"allow_all"`: Allow generation of people of all ages.
|
||||
///
|
||||
/// The default value is `"allow_adult"`.
|
||||
///
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub person_generation: Option<PersonGeneration>,
|
||||
|
||||
/// Optional. The language code that corresponds to your text prompt language.
|
||||
/// The following values are supported:
|
||||
/// - auto: Automatic detection. If Imagen detects a supported language, the prompt and an
|
||||
/// optional negative prompt are translated to English. If the language detected isn't
|
||||
/// supported, Imagen uses the input text verbatim, which might result in an unexpected
|
||||
/// output. No error code is returned.
|
||||
/// - en: English (if omitted, the default value)
|
||||
/// - zh or zh-CN: Chinese (simplified)
|
||||
/// - zh-TW: Chinese (traditional)
|
||||
/// - hi: Hindi
|
||||
/// - ja: Japanese
|
||||
/// - ko: Korean
|
||||
/// - pt: Portuguese
|
||||
/// - es: Spanish
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub language: Option<String>,
|
||||
|
||||
/// Adds a filter level to safety filtering. The following values are supported:
|
||||
///
|
||||
/// - "block_low_and_above": Strongest filtering level, most strict blocking.
|
||||
/// Deprecated value: "block_most".
|
||||
/// - "block_medium_and_above": Block some problematic prompts and responses.
|
||||
/// Deprecated value: "block_some".
|
||||
/// - "block_only_high": Reduces the number of requests blocked due to safety filters. May
|
||||
/// increase objectionable content generated by Imagen. Deprecated value: "block_few".
|
||||
/// - "block_none": Block very few problematic prompts and responses. Access to this feature
|
||||
/// is restricted. Previous field value: "block_fewest".
|
||||
///
|
||||
/// The default value is "block_medium_and_above".
|
||||
///
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safety_setting: Option<PredictImageSafetySetting>,
|
||||
|
||||
/// Add an invisible watermark to the generated images. The default value is `false` for the
|
||||
/// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the
|
||||
/// `imagen-3.0-fast-generate-001`, `imagegeneration@006`, and imagegeneration@006 models.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub add_watermark: Option<bool>,
|
||||
|
||||
/// Cloud Storage URI to store the generated images.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub storage_uri: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageRequestParametersOutputOptions {
|
||||
/// The image format that the output should be saved as. The following values are supported:
|
||||
///
|
||||
/// - "image/png": Save as a PNG image
|
||||
/// - "image/jpeg": Save as a JPEG image
|
||||
///
|
||||
/// The default value is "image/png".v
|
||||
pub mime_type: Option<String>,
|
||||
|
||||
/// The level of compression if the output type is "image/jpeg".
|
||||
/// Accepted values are 0 through 100. The default value is 75.
|
||||
pub compression_quality: Option<i32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct PredictImageResponse {
|
||||
pub predictions: Vec<PredictImageResponsePrediction>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictImageResponsePrediction {
|
||||
#[serde_as(as = "Base64")]
|
||||
pub bytes_base64_encoded: Vec<u8>,
|
||||
pub mime_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PersonGeneration {
|
||||
DontAllow,
|
||||
AllowAdult,
|
||||
AllowAll,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PredictImageSafetySetting {
|
||||
BlockLowAndAbove,
|
||||
BlockMediumAndAbove,
|
||||
BlockOnlyHigh,
|
||||
BlockNone,
|
||||
}
|
||||
61
src/types/text_embeddings.rs
Normal file
61
src/types/text_embeddings.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{Error, Result};
|
||||
|
||||
use super::VertexApiError;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingRequest {
|
||||
pub instances: Vec<TextEmbeddingRequestInstance>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingRequestInstance {
|
||||
pub content: String,
|
||||
pub task_type: String,
|
||||
pub title: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum TextEmbeddingResponse {
|
||||
Ok(TextEmbeddingResponseOk),
|
||||
Error { error: VertexApiError },
|
||||
}
|
||||
|
||||
impl TextEmbeddingResponse {
|
||||
pub fn into_result(self) -> Result<TextEmbeddingResponseOk> {
|
||||
self.into()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingResponseOk {
|
||||
pub predictions: Vec<TextEmbeddingPrediction>,
|
||||
}
|
||||
|
||||
impl From<TextEmbeddingResponse> for Result<TextEmbeddingResponseOk> {
|
||||
fn from(value: TextEmbeddingResponse) -> Self {
|
||||
match value {
|
||||
TextEmbeddingResponse::Ok(ok) => Ok(ok),
|
||||
TextEmbeddingResponse::Error { error } => Err(Error::VertexError(error)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingPrediction {
|
||||
pub embeddings: TextEmbeddingResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingResult {
|
||||
pub statistics: TextEmbeddingStatistics,
|
||||
pub values: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingStatistics {
|
||||
pub truncated: bool,
|
||||
pub token_count: u32,
|
||||
}
|
||||
Reference in New Issue
Block a user