Compare commits

..

10 Commits

Author SHA1 Message Date
1d467138ed Move predict image safety settings to enum
Some checks failed
Rust / build (push) Has been cancelled
2025-06-08 09:24:34 +01:00
9a8732f609 Move PersonGeneration parameter to enum 2025-06-08 09:11:01 +01:00
6e64b8fd72 Add missing parameters to imagegen 2025-06-08 08:58:48 +01:00
d1678bdc37 Define specific error for event stream ended 2025-04-06 08:12:58 +01:00
8ff379d040 Fixes lifetime for TokenProvider 2025-04-05 19:51:07 +01:00
51d6a27017 Adds streaming generation that returns a streaming
- Introducues generate_content_stream that returns a Tokio Steam instead
of a Queue. This allows using the standard stream APIs from
tokio-streams.
- Replace future-utils with tokio-streams, mainly due to better
ergonomics for using the filter_map stream combinator.
2025-04-05 19:18:33 +01:00
de9b14b984 Add support for the Google Search tool
- Gemini 2.0 and above use a google_search tool instead of a
google_search_retrieval tool. This commit adds support for the new tool.
2025-03-26 20:47:48 +00:00
83663680c9 Increase embedding precision to f64 2024-12-31 08:39:46 +00:00
fd1223da59 More builder refactoring 2024-11-27 16:20:53 +00:00
56d6b95c53 Add shortcuts for commonly used functions 2024-11-27 15:57:18 +00:00
15 changed files with 252 additions and 115 deletions

View File

@@ -7,7 +7,6 @@ edition = "2021"
[dependencies] [dependencies]
deadqueue = "0.2" deadqueue = "0.2"
futures-util = "0.3"
gcp_auth = "0.12" gcp_auth = "0.12"
reqwest = { version = "0.12", features = ["json", "gzip"] } reqwest = { version = "0.12", features = ["json", "gzip"] }
reqwest-eventsource = "0.6" reqwest-eventsource = "0.6"
@@ -16,6 +15,7 @@ serde_json = { version = "1"}
serde_with = { version = "3.9", features = ["base64"]} serde_with = { version = "3.9", features = ["base64"]}
tracing = "0.1" tracing = "0.1"
tokio = { version = "1" } tokio = { version = "1" }
tokio-stream = "0.1.17"
[dev-dependencies] [dev-dependencies]
console = "0.15.8" console = "0.15.8"

View File

@@ -1,8 +1,9 @@
use std::{error::Error, io::Cursor}; use std::{error::Error, io::Cursor};
use gemini_rs::prelude::{ use gemini_rs::prelude::{
GeminiClient, PredictImageRequest, PredictImageRequestParameters, GeminiClient, PersonGeneration, PredictImageRequest, PredictImageRequestParameters,
PredictImageRequestParametersOutputOptions, PredictImageRequestPrompt, PredictImageRequestParametersOutputOptions, PredictImageRequestPrompt,
PredictImageSafetySetting,
}; };
use image::{ImageFormat, ImageReader}; use image::{ImageFormat, ImageReader};
@@ -34,9 +35,14 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
mime_type: Some("image/jpeg".to_string()), mime_type: Some("image/jpeg".to_string()),
compression_quality: Some(75), compression_quality: Some(75),
}), }),
person_generation: Some(PersonGeneration::AllowAll),
safety_setting: Some(PredictImageSafetySetting::BlockOnlyHigh),
..Default::default() ..Default::default()
}, },
}; };
println!("Request: {:#?}", serde_json::to_string(&request).unwrap());
let mut result = gemini let mut result = gemini
.predict_image(&request, "imagen-3.0-fast-generate-001") .predict_image(&request, "imagen-3.0-fast-generate-001")
.await?; .await?;

View File

@@ -29,8 +29,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
..Default::default() ..Default::default()
}; };
println!(
"Request: {}",
serde_json::to_string_pretty(&request).unwrap()
);
let result = gemini let result = gemini
.generate_content(&request, "gemini-1.0-pro-002") .generate_content(&request, "gemini-1.5-flash-002")
.await?; .await?;
println!("Response: {:?}", result.candidates[0].get_text().unwrap()); println!("Response: {:?}", result.candidates[0].get_text().unwrap());

44
examples/google-search.rs Normal file
View File

@@ -0,0 +1,44 @@
use gemini_rs::prelude::*;
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
tracing_subscriber::fmt().init();
let authentication_manager = gcp_auth::provider().await?;
let api_endpoint = std::env::var("API_ENDPOINT")?;
let project_id = std::env::var("PROJECT_ID")?;
let location_id = std::env::var("LOCATION_ID")?;
let gemini = GeminiClient::new(
authentication_manager,
api_endpoint,
project_id,
location_id,
);
let prompt = "What day is today?";
let request = GenerateContentRequest {
contents: vec![Content {
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
tools: Some(vec![Tools {
google_search: Some(GoogleSearch::default()),
..Default::default()
}]),
..Default::default()
};
println!(
"Request: {}",
serde_json::to_string_pretty(&request).unwrap()
);
let result = gemini
.generate_content(&request, "gemini-2.0-flash-001")
.await?;
println!("Response: {:?}", result.candidates[0].get_text().unwrap());
Ok(())
}

View File

@@ -15,20 +15,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
location_id, location_id,
); );
let system_instruction = "Answer as if you were Yoda"; let system_instruction = Content::builder()
let prompt = "What is the airspeed of an unladen swallow?"; .add_text_part("Answer as if you were Yoda")
.build();
let request = GenerateContentRequest { let user_prompt = vec![Content::builder()
contents: vec![Content { .role(Role::User)
role: Some(Role::User), .add_text_part("What is the airspeed of an unladen swallow?")
parts: Some(vec![Part::Text(prompt.to_string())]), .build()];
}],
system_instruction: Some(Content { let request = GenerateContentRequest::builder()
role: None, .contents(user_prompt)
parts: Some(vec![Part::Text(system_instruction.to_string())]), .system_instruction(system_instruction)
}), .build();
..Default::default()
};
let result = gemini let result = gemini
.generate_content(&request, "gemini-1.0-pro-002") .generate_content(&request, "gemini-1.0-pro-002")

View File

@@ -1,4 +1,5 @@
use gemini_rs::prelude::*; use gemini_rs::prelude::*;
use tokio_stream::StreamExt;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -14,32 +15,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
location_id, location_id,
); );
let prompt = "Tell me the story of the genesis of the universe as a bedtime story."; let prompt = vec![Content::builder()
let request = GenerateContentRequest::builder() .role(Role::User)
.add_content( .add_text_part("Tell me the story of the genesis of the universe as a bedtime story.")
Content::builder() .build()];
.role(Role::User)
.add_part(Part::Text(prompt.to_string()))
.build(),
)
.build();
let queue = gemini.stream_generate_content(&request, "gemini-pro").await; let request = GenerateContentRequest::builder().contents(prompt).build();
while let Some(response) = queue.pop().await { let mut queue = gemini
match response { .generate_content_stream(&request, "gemini-2.0-flash-001")
Ok(result) => { .await?;
let text = result
.candidates while let Some(Ok(response)) = queue.next().await {
.iter() println!("Response: {:?}", response);
.filter_map(|c| c.get_text())
.collect::<String>();
print!("{}", text);
}
Err(error) => {
println!("{error}");
}
}
} }
Ok(()) Ok(())

View File

@@ -14,15 +14,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
location_id, location_id,
); );
let prompt = "What is the airspeed of an unladen swallow?"; let prompt = vec![Content::builder()
let request = GenerateContentRequest::builder() .role(Role::User)
.add_content( .add_text_part("What is the airspeed of an unladen swallow?")
Content::builder() .build()];
.role(Role::User)
.add_part(Part::Text(prompt.to_string())) let request = GenerateContentRequest::builder().contents(prompt).build();
.build(),
)
.build();
let response = gemini.generate_content(&request, "gemini-pro").await?; let response = gemini.generate_content(&request, "gemini-pro").await?;
println!("Response: {:?}", response.candidates[0].get_text().unwrap()); println!("Response: {:?}", response.candidates[0].get_text().unwrap());

View File

@@ -1,7 +1,9 @@
use crate::error::Result as GeminiResult;
use std::sync::Arc; use std::sync::Arc;
use std::vec;
use tokio_stream::{Stream, StreamExt};
use deadqueue::unlimited::Queue; use deadqueue::unlimited::Queue;
use futures_util::stream::StreamExt;
use reqwest_eventsource::{Event, EventSource}; use reqwest_eventsource::{Event, EventSource};
use tracing::error; use tracing::error;
@@ -9,7 +11,7 @@ use crate::dialogue::Message;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::prelude::{ use crate::prelude::{
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest, Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest, GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest,
TextEmbeddingResponse, TextEmbeddingResponse,
}; };
use crate::types::{PredictImageRequest, PredictImageResponse, Role}; use crate::types::{PredictImageRequest, PredictImageResponse, Role};
@@ -45,6 +47,53 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
} }
} }
pub async fn generate_content_stream(
&self,
request: &GenerateContentRequest,
model: &str,
) -> Result<impl Stream<Item = GeminiResult<GenerateContentResponseResult>>> {
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
let endpoint_url = format!(
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model,
);
let client = self.client.clone();
let request = request.clone();
let req = client
.post(&endpoint_url)
.bearer_auth(access_token)
.json(&request);
let event_source = EventSource::new(req).unwrap();
let mapped = event_source.filter_map(|event| {
let event = match event {
Ok(event) => event,
Err(reqwest_eventsource::Error::StreamEnded) => {
return Some(Err(Error::EventSourceClosedError))
}
Err(e) => return Some(Err(e.into())),
};
let Event::Message(event_message) = event else {
return None;
};
let gemini_response: GenerateContentResponse =
match serde_json::from_str(&event_message.data) {
Ok(gemini_response) => gemini_response,
Err(e) => return Some(Err(e.into())),
};
let gemini_response = match gemini_response.into_result() {
Ok(gemini_response) => gemini_response,
Err(e) => return Some(Err(e)),
};
Some(Ok(gemini_response))
});
Ok(mapped)
}
pub async fn stream_generate_content( pub async fn stream_generate_content(
&self, &self,
request: &GenerateContentRequest, request: &GenerateContentRequest,
@@ -175,34 +224,6 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
} }
} }
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
/// from the response.
#[deprecated(note = "Use `generate_content` instead")]
pub async fn prompt_text(
&self,
prompt: &str,
generation_config: Option<&GenerationConfig>,
) -> Result<String> {
let request = GenerateContentRequest {
contents: vec![Content {
role: Some(Role::User),
parts: Some(vec![Part::Text(prompt.to_string())]),
}],
generation_config: generation_config.cloned(),
tools: None,
system_instruction: None,
safety_settings: None,
};
let response = self.generate_content(&request, "gemini-pro").await?;
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
match candidates.pop() {
Some(candidate) => Ok(candidate),
None => Err(Error::NoCandidatesError),
}
}
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> { fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
response response
.candidates .candidates
@@ -278,10 +299,10 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
let txt_json = resp.text().await?; let txt_json = resp.text().await?;
match serde_json::from_str::<PredictImageResponse>(&txt_json) { match serde_json::from_str::<PredictImageResponse>(&txt_json) {
Ok(response) => return Ok(response), Ok(response) => Ok(response),
Err(e) => { Err(e) => {
error!(response = txt_json, error = ?e, "Failed to parse response"); error!(response = txt_json, error = ?e, "Failed to parse response");
return Err(e.into()); Err(e.into())
} }
} }
} }

View File

@@ -14,7 +14,9 @@ pub enum Error {
Serde(serde_json::Error), Serde(serde_json::Error),
VertexError(types::VertexApiError), VertexError(types::VertexApiError),
NoCandidatesError, NoCandidatesError,
EventSourceError(CannotCloneRequestError), CannotCloneRequestError(CannotCloneRequestError),
EventSourceError(reqwest_eventsource::Error),
EventSourceClosedError,
} }
impl Display for Error { impl Display for Error {
@@ -30,9 +32,15 @@ impl Display for Error {
Error::NoCandidatesError => { Error::NoCandidatesError => {
write!(f, "No candidates returned for the prompt") write!(f, "No candidates returned for the prompt")
} }
Error::CannotCloneRequestError(e) => {
write!(f, "Cannot clone request: {}", e)
}
Error::EventSourceError(e) => { Error::EventSourceError(e) => {
write!(f, "EventSourrce Error: {}", e) write!(f, "EventSourrce Error: {}", e)
} }
Error::EventSourceClosedError => {
write!(f, "EventSource closed error")
}
} }
} }
} }
@@ -71,6 +79,12 @@ impl From<types::VertexApiError> for Error {
impl From<CannotCloneRequestError> for Error { impl From<CannotCloneRequestError> for Error {
fn from(e: CannotCloneRequestError) -> Self { fn from(e: CannotCloneRequestError) -> Self {
Error::CannotCloneRequestError(e)
}
}
impl From<reqwest_eventsource::Error> for Error {
fn from(e: reqwest_eventsource::Error) -> Self {
Error::EventSourceError(e) Error::EventSourceError(e)
} }
} }

View File

@@ -7,7 +7,7 @@ pub trait TokenProvider {
-> impl std::future::Future<Output = Result<String>> + Send; -> impl std::future::Future<Output = Result<String>> + Send;
} }
impl<'a> TokenProvider for Arc<dyn gcp_auth::TokenProvider + 'a> { impl TokenProvider for Arc<dyn gcp_auth::TokenProvider + '_> {
async fn get_token(&self, scope: &[&str]) -> Result<String> { async fn get_token(&self, scope: &[&str]) -> Result<String> {
let token = self.token(scope).await; let token = self.token(scope).await;
match token { match token {

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, str::FromStr, vec}; use std::{collections::HashMap, fmt::Display, str::FromStr, vec};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -22,19 +22,18 @@ impl Content {
} }
pub fn builder() -> ContentBuilder { pub fn builder() -> ContentBuilder {
ContentBuilder::new() ContentBuilder::default()
} }
} }
#[derive(Default)]
pub struct ContentBuilder { pub struct ContentBuilder {
content: Content, content: Content,
} }
impl ContentBuilder { impl ContentBuilder {
pub fn new() -> Self { pub fn add_text_part<T: Into<String>>(self, text: T) -> Self {
Self { self.add_part(Part::Text(text.into()))
content: Default::default(),
}
} }
pub fn add_part(mut self, part: Part) -> Self { pub fn add_part(mut self, part: Part) -> Self {
@@ -62,12 +61,13 @@ pub enum Role {
Model, Model,
} }
impl ToString for Role { impl Display for Role {
fn to_string(&self) -> String { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { let role_str = match self {
Role::User => "user".to_string(), Role::User => "user",
Role::Model => "model".to_string(), Role::Model => "model",
} };
f.write_str(role_str)
} }
} }

View File

@@ -9,21 +9,16 @@ pub struct CountTokensRequest {
impl CountTokensRequest { impl CountTokensRequest {
pub fn builder() -> CountTokensRequestBuilder { pub fn builder() -> CountTokensRequestBuilder {
CountTokensRequestBuilder::new() CountTokensRequestBuilder::default()
} }
} }
#[derive(Default)]
pub struct CountTokensRequestBuilder { pub struct CountTokensRequestBuilder {
contents: Content, contents: Content,
} }
impl CountTokensRequestBuilder { impl CountTokensRequestBuilder {
pub fn new() -> Self {
CountTokensRequestBuilder {
contents: Content::default(),
}
}
pub fn from_prompt(prompt: &str) -> Self { pub fn from_prompt(prompt: &str) -> Self {
CountTokensRequestBuilder { CountTokensRequestBuilder {
contents: Content { contents: Content {

View File

@@ -37,8 +37,8 @@ impl GenerateContentRequestBuilder {
} }
} }
pub fn add_content(mut self, content: Content) -> Self { pub fn contents(mut self, contents: Vec<Content>) -> Self {
self.request.contents.push(content); self.request.contents = contents;
self self
} }
@@ -71,15 +71,39 @@ impl GenerateContentRequestBuilder {
pub struct Tools { pub struct Tools {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub function_declarations: Option<Vec<FunctionDeclaration>>, pub function_declarations: Option<Vec<FunctionDeclaration>>,
#[serde(rename = "googleSearchRetrieval")] #[serde(rename = "googleSearchRetrieval")]
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub google_search_retrieval: Option<GoogleSearchRetrieval>, pub google_search_retrieval: Option<GoogleSearchRetrieval>,
#[serde(skip_serializing_if = "Option::is_none")]
pub google_search: Option<GoogleSearch>,
}
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct GoogleSearch {}
#[derive(Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DynamicRetrievalConfig {
pub mode: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub dynamic_threshold: Option<f32>,
}
impl Default for DynamicRetrievalConfig {
fn default() -> Self {
Self {
mode: "MODE_DYNAMIC".to_string(),
dynamic_threshold: Some(0.7),
}
}
} }
#[derive(Clone, Default, Serialize, Deserialize)] #[derive(Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
pub struct GoogleSearchRetrieval { pub struct GoogleSearchRetrieval {
pub disable_attribution: bool, pub dynamic_retrieval_config: DynamicRetrievalConfig,
} }
#[derive(Clone, Debug, Serialize, Deserialize, Default)] #[derive(Clone, Debug, Serialize, Deserialize, Default)]

View File

@@ -37,6 +37,12 @@ pub struct PredictImageRequestParameters {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub seed: Option<u32>, pub seed: Option<u32>,
/// Optional. An optional parameter to use an LLM-based prompt rewriting feature to deliver
/// higher quality images that better reflect the original prompt's intent. Disabling this
/// feature may impact image quality and prompt adherence
#[serde(skip_serializing_if = "Option::is_none")]
pub enhance_prompt: Option<bool>,
/// A description of what to discourage in the generated images. /// A description of what to discourage in the generated images.
/// The following models support this parameter: /// The following models support this parameter:
/// - `imagen-3.0-generate-001`: up to 480 tokens. /// - `imagen-3.0-generate-001`: up to 480 tokens.
@@ -70,6 +76,7 @@ pub struct PredictImageRequestParameters {
/// - "watercolor" /// - "watercolor"
/// - "cyberpunk" /// - "cyberpunk"
/// - "pop_art" /// - "pop_art"
///
/// Pre-defined styles is only supported for model imagegeneration@002 /// Pre-defined styles is only supported for model imagegeneration@002
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub sample_image_style: Option<String>, pub sample_image_style: Option<String>,
@@ -84,22 +91,42 @@ pub struct PredictImageRequestParameters {
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and /// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
/// `imagegeneration@006` only. /// `imagegeneration@006` only.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub person_generation: Option<String>, pub person_generation: Option<PersonGeneration>,
/// Optional. The language code that corresponds to your text prompt language.
/// The following values are supported:
/// - auto: Automatic detection. If Imagen detects a supported language, the prompt and an
/// optional negative prompt are translated to English. If the language detected isn't
/// supported, Imagen uses the input text verbatim, which might result in an unexpected
/// output. No error code is returned.
/// - en: English (if omitted, the default value)
/// - zh or zh-CN: Chinese (simplified)
/// - zh-TW: Chinese (traditional)
/// - hi: Hindi
/// - ja: Japanese
/// - ko: Korean
/// - pt: Portuguese
/// - es: Spanish
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
/// Adds a filter level to safety filtering. The following values are supported: /// Adds a filter level to safety filtering. The following values are supported:
/// - `"block_most"`: Strongest filtering level, most strict blocking.
/// - `"block_some"`: Block some problematic prompts and responses.
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
/// increase objectionable content generated by Imagen.
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
/// feature is restricted.
/// ///
/// The default value is `"block_some"`. /// - "block_low_and_above": Strongest filtering level, most strict blocking.
/// Deprecated value: "block_most".
/// - "block_medium_and_above": Block some problematic prompts and responses.
/// Deprecated value: "block_some".
/// - "block_only_high": Reduces the number of requests blocked due to safety filters. May
/// increase objectionable content generated by Imagen. Deprecated value: "block_few".
/// - "block_none": Block very few problematic prompts and responses. Access to this feature
/// is restricted. Previous field value: "block_fewest".
///
/// The default value is "block_medium_and_above".
/// ///
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and /// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
/// `imagegeneration@006` only. /// `imagegeneration@006` only.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub safety_setting: Option<String>, pub safety_setting: Option<PredictImageSafetySetting>,
/// Add an invisible watermark to the generated images. The default value is `false` for the /// Add an invisible watermark to the generated images. The default value is `false` for the
/// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the /// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the
@@ -141,3 +168,20 @@ pub struct PredictImageResponsePrediction {
pub bytes_base64_encoded: Vec<u8>, pub bytes_base64_encoded: Vec<u8>,
pub mime_type: String, pub mime_type: String,
} }
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PersonGeneration {
DontAllow,
AllowAdult,
AllowAll,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PredictImageSafetySetting {
BlockLowAndAbove,
BlockMediumAndAbove,
BlockOnlyHigh,
BlockNone,
}

View File

@@ -51,7 +51,7 @@ pub struct TextEmbeddingPrediction {
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TextEmbeddingResult { pub struct TextEmbeddingResult {
pub statistics: TextEmbeddingStatistics, pub statistics: TextEmbeddingStatistics,
pub values: Vec<f32>, pub values: Vec<f64>,
} }
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]