Compare commits
10 Commits
db5a01afef
...
1d467138ed
| Author | SHA1 | Date | |
|---|---|---|---|
| 1d467138ed | |||
| 9a8732f609 | |||
| 6e64b8fd72 | |||
| d1678bdc37 | |||
| 8ff379d040 | |||
| 51d6a27017 | |||
| de9b14b984 | |||
| 83663680c9 | |||
| fd1223da59 | |||
| 56d6b95c53 |
@@ -7,7 +7,6 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
deadqueue = "0.2"
|
||||
futures-util = "0.3"
|
||||
gcp_auth = "0.12"
|
||||
reqwest = { version = "0.12", features = ["json", "gzip"] }
|
||||
reqwest-eventsource = "0.6"
|
||||
@@ -16,6 +15,7 @@ serde_json = { version = "1"}
|
||||
serde_with = { version = "3.9", features = ["base64"]}
|
||||
tracing = "0.1"
|
||||
tokio = { version = "1" }
|
||||
tokio-stream = "0.1.17"
|
||||
|
||||
[dev-dependencies]
|
||||
console = "0.15.8"
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use std::{error::Error, io::Cursor};
|
||||
|
||||
use gemini_rs::prelude::{
|
||||
GeminiClient, PredictImageRequest, PredictImageRequestParameters,
|
||||
GeminiClient, PersonGeneration, PredictImageRequest, PredictImageRequestParameters,
|
||||
PredictImageRequestParametersOutputOptions, PredictImageRequestPrompt,
|
||||
PredictImageSafetySetting,
|
||||
};
|
||||
use image::{ImageFormat, ImageReader};
|
||||
|
||||
@@ -34,9 +35,14 @@ pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||
mime_type: Some("image/jpeg".to_string()),
|
||||
compression_quality: Some(75),
|
||||
}),
|
||||
person_generation: Some(PersonGeneration::AllowAll),
|
||||
safety_setting: Some(PredictImageSafetySetting::BlockOnlyHigh),
|
||||
..Default::default()
|
||||
},
|
||||
};
|
||||
|
||||
println!("Request: {:#?}", serde_json::to_string(&request).unwrap());
|
||||
|
||||
let mut result = gemini
|
||||
.predict_image(&request, "imagen-3.0-fast-generate-001")
|
||||
.await?;
|
||||
|
||||
@@ -29,8 +29,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
println!(
|
||||
"Request: {}",
|
||||
serde_json::to_string_pretty(&request).unwrap()
|
||||
);
|
||||
|
||||
let result = gemini
|
||||
.generate_content(&request, "gemini-1.0-pro-002")
|
||||
.generate_content(&request, "gemini-1.5-flash-002")
|
||||
.await?;
|
||||
|
||||
println!("Response: {:?}", result.candidates[0].get_text().unwrap());
|
||||
|
||||
44
examples/google-search.rs
Normal file
44
examples/google-search.rs
Normal file
@@ -0,0 +1,44 @@
|
||||
use gemini_rs::prelude::*;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tracing_subscriber::fmt().init();
|
||||
let authentication_manager = gcp_auth::provider().await?;
|
||||
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||
let project_id = std::env::var("PROJECT_ID")?;
|
||||
let location_id = std::env::var("LOCATION_ID")?;
|
||||
|
||||
let gemini = GeminiClient::new(
|
||||
authentication_manager,
|
||||
api_endpoint,
|
||||
project_id,
|
||||
location_id,
|
||||
);
|
||||
|
||||
let prompt = "What day is today?";
|
||||
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
tools: Some(vec![Tools {
|
||||
google_search: Some(GoogleSearch::default()),
|
||||
..Default::default()
|
||||
}]),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
println!(
|
||||
"Request: {}",
|
||||
serde_json::to_string_pretty(&request).unwrap()
|
||||
);
|
||||
|
||||
let result = gemini
|
||||
.generate_content(&request, "gemini-2.0-flash-001")
|
||||
.await?;
|
||||
|
||||
println!("Response: {:?}", result.candidates[0].get_text().unwrap());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -15,20 +15,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
location_id,
|
||||
);
|
||||
|
||||
let system_instruction = "Answer as if you were Yoda";
|
||||
let prompt = "What is the airspeed of an unladen swallow?";
|
||||
let system_instruction = Content::builder()
|
||||
.add_text_part("Answer as if you were Yoda")
|
||||
.build();
|
||||
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
system_instruction: Some(Content {
|
||||
role: None,
|
||||
parts: Some(vec![Part::Text(system_instruction.to_string())]),
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
let user_prompt = vec![Content::builder()
|
||||
.role(Role::User)
|
||||
.add_text_part("What is the airspeed of an unladen swallow?")
|
||||
.build()];
|
||||
|
||||
let request = GenerateContentRequest::builder()
|
||||
.contents(user_prompt)
|
||||
.system_instruction(system_instruction)
|
||||
.build();
|
||||
|
||||
let result = gemini
|
||||
.generate_content(&request, "gemini-1.0-pro-002")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
use gemini_rs::prelude::*;
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
@@ -14,32 +15,19 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
location_id,
|
||||
);
|
||||
|
||||
let prompt = "Tell me the story of the genesis of the universe as a bedtime story.";
|
||||
let request = GenerateContentRequest::builder()
|
||||
.add_content(
|
||||
Content::builder()
|
||||
.role(Role::User)
|
||||
.add_part(Part::Text(prompt.to_string()))
|
||||
.build(),
|
||||
)
|
||||
.build();
|
||||
let prompt = vec![Content::builder()
|
||||
.role(Role::User)
|
||||
.add_text_part("Tell me the story of the genesis of the universe as a bedtime story.")
|
||||
.build()];
|
||||
|
||||
let queue = gemini.stream_generate_content(&request, "gemini-pro").await;
|
||||
let request = GenerateContentRequest::builder().contents(prompt).build();
|
||||
|
||||
while let Some(response) = queue.pop().await {
|
||||
match response {
|
||||
Ok(result) => {
|
||||
let text = result
|
||||
.candidates
|
||||
.iter()
|
||||
.filter_map(|c| c.get_text())
|
||||
.collect::<String>();
|
||||
print!("{}", text);
|
||||
}
|
||||
Err(error) => {
|
||||
println!("{error}");
|
||||
}
|
||||
}
|
||||
let mut queue = gemini
|
||||
.generate_content_stream(&request, "gemini-2.0-flash-001")
|
||||
.await?;
|
||||
|
||||
while let Some(Ok(response)) = queue.next().await {
|
||||
println!("Response: {:?}", response);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -14,15 +14,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
location_id,
|
||||
);
|
||||
|
||||
let prompt = "What is the airspeed of an unladen swallow?";
|
||||
let request = GenerateContentRequest::builder()
|
||||
.add_content(
|
||||
Content::builder()
|
||||
.role(Role::User)
|
||||
.add_part(Part::Text(prompt.to_string()))
|
||||
.build(),
|
||||
)
|
||||
.build();
|
||||
let prompt = vec![Content::builder()
|
||||
.role(Role::User)
|
||||
.add_text_part("What is the airspeed of an unladen swallow?")
|
||||
.build()];
|
||||
|
||||
let request = GenerateContentRequest::builder().contents(prompt).build();
|
||||
let response = gemini.generate_content(&request, "gemini-pro").await?;
|
||||
println!("Response: {:?}", response.candidates[0].get_text().unwrap());
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
use crate::error::Result as GeminiResult;
|
||||
use std::sync::Arc;
|
||||
use std::vec;
|
||||
use tokio_stream::{Stream, StreamExt};
|
||||
|
||||
use deadqueue::unlimited::Queue;
|
||||
use futures_util::stream::StreamExt;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use tracing::error;
|
||||
|
||||
@@ -9,7 +11,7 @@ use crate::dialogue::Message;
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||
GenerateContentResponse, GenerateContentResponseResult, TextEmbeddingRequest,
|
||||
TextEmbeddingResponse,
|
||||
};
|
||||
use crate::types::{PredictImageRequest, PredictImageResponse, Role};
|
||||
@@ -45,6 +47,53 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn generate_content_stream(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
model: &str,
|
||||
) -> Result<impl Stream<Item = GeminiResult<GenerateContentResponseResult>>> {
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await.unwrap();
|
||||
let endpoint_url = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:streamGenerateContent?alt=sse", self.api_endpoint, self.project_id, self.location_id, model,
|
||||
);
|
||||
let client = self.client.clone();
|
||||
let request = request.clone();
|
||||
let req = client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request);
|
||||
|
||||
let event_source = EventSource::new(req).unwrap();
|
||||
|
||||
let mapped = event_source.filter_map(|event| {
|
||||
let event = match event {
|
||||
Ok(event) => event,
|
||||
Err(reqwest_eventsource::Error::StreamEnded) => {
|
||||
return Some(Err(Error::EventSourceClosedError))
|
||||
}
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let Event::Message(event_message) = event else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let gemini_response: GenerateContentResponse =
|
||||
match serde_json::from_str(&event_message.data) {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e.into())),
|
||||
};
|
||||
|
||||
let gemini_response = match gemini_response.into_result() {
|
||||
Ok(gemini_response) => gemini_response,
|
||||
Err(e) => return Some(Err(e)),
|
||||
};
|
||||
|
||||
Some(Ok(gemini_response))
|
||||
});
|
||||
Ok(mapped)
|
||||
}
|
||||
|
||||
pub async fn stream_generate_content(
|
||||
&self,
|
||||
request: &GenerateContentRequest,
|
||||
@@ -175,34 +224,6 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Sends a text prompt to the Vertex API using the Gemini Pro model and extracts the text
|
||||
/// from the response.
|
||||
#[deprecated(note = "Use `generate_content` instead")]
|
||||
pub async fn prompt_text(
|
||||
&self,
|
||||
prompt: &str,
|
||||
generation_config: Option<&GenerationConfig>,
|
||||
) -> Result<String> {
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some(Role::User),
|
||||
parts: Some(vec![Part::Text(prompt.to_string())]),
|
||||
}],
|
||||
generation_config: generation_config.cloned(),
|
||||
tools: None,
|
||||
system_instruction: None,
|
||||
safety_settings: None,
|
||||
};
|
||||
|
||||
let response = self.generate_content(&request, "gemini-pro").await?;
|
||||
let mut candidates = GeminiClient::<T>::collect_text_from_response(&response);
|
||||
|
||||
match candidates.pop() {
|
||||
Some(candidate) => Ok(candidate),
|
||||
None => Err(Error::NoCandidatesError),
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_text_from_response(response: &GenerateContentResponseResult) -> Vec<String> {
|
||||
response
|
||||
.candidates
|
||||
@@ -278,10 +299,10 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
let txt_json = resp.text().await?;
|
||||
|
||||
match serde_json::from_str::<PredictImageResponse>(&txt_json) {
|
||||
Ok(response) => return Ok(response),
|
||||
Ok(response) => Ok(response),
|
||||
Err(e) => {
|
||||
error!(response = txt_json, error = ?e, "Failed to parse response");
|
||||
return Err(e.into());
|
||||
Err(e.into())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
16
src/error.rs
16
src/error.rs
@@ -14,7 +14,9 @@ pub enum Error {
|
||||
Serde(serde_json::Error),
|
||||
VertexError(types::VertexApiError),
|
||||
NoCandidatesError,
|
||||
EventSourceError(CannotCloneRequestError),
|
||||
CannotCloneRequestError(CannotCloneRequestError),
|
||||
EventSourceError(reqwest_eventsource::Error),
|
||||
EventSourceClosedError,
|
||||
}
|
||||
|
||||
impl Display for Error {
|
||||
@@ -30,9 +32,15 @@ impl Display for Error {
|
||||
Error::NoCandidatesError => {
|
||||
write!(f, "No candidates returned for the prompt")
|
||||
}
|
||||
Error::CannotCloneRequestError(e) => {
|
||||
write!(f, "Cannot clone request: {}", e)
|
||||
}
|
||||
Error::EventSourceError(e) => {
|
||||
write!(f, "EventSourrce Error: {}", e)
|
||||
}
|
||||
Error::EventSourceClosedError => {
|
||||
write!(f, "EventSource closed error")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,6 +79,12 @@ impl From<types::VertexApiError> for Error {
|
||||
|
||||
impl From<CannotCloneRequestError> for Error {
|
||||
fn from(e: CannotCloneRequestError) -> Self {
|
||||
Error::CannotCloneRequestError(e)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<reqwest_eventsource::Error> for Error {
|
||||
fn from(e: reqwest_eventsource::Error) -> Self {
|
||||
Error::EventSourceError(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ pub trait TokenProvider {
|
||||
-> impl std::future::Future<Output = Result<String>> + Send;
|
||||
}
|
||||
|
||||
impl<'a> TokenProvider for Arc<dyn gcp_auth::TokenProvider + 'a> {
|
||||
impl TokenProvider for Arc<dyn gcp_auth::TokenProvider + '_> {
|
||||
async fn get_token(&self, scope: &[&str]) -> Result<String> {
|
||||
let token = self.token(scope).await;
|
||||
match token {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::{collections::HashMap, str::FromStr, vec};
|
||||
use std::{collections::HashMap, fmt::Display, str::FromStr, vec};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -22,19 +22,18 @@ impl Content {
|
||||
}
|
||||
|
||||
pub fn builder() -> ContentBuilder {
|
||||
ContentBuilder::new()
|
||||
ContentBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ContentBuilder {
|
||||
content: Content,
|
||||
}
|
||||
|
||||
impl ContentBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
content: Default::default(),
|
||||
}
|
||||
pub fn add_text_part<T: Into<String>>(self, text: T) -> Self {
|
||||
self.add_part(Part::Text(text.into()))
|
||||
}
|
||||
|
||||
pub fn add_part(mut self, part: Part) -> Self {
|
||||
@@ -62,12 +61,13 @@ pub enum Role {
|
||||
Model,
|
||||
}
|
||||
|
||||
impl ToString for Role {
|
||||
fn to_string(&self) -> String {
|
||||
match self {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Model => "model".to_string(),
|
||||
}
|
||||
impl Display for Role {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let role_str = match self {
|
||||
Role::User => "user",
|
||||
Role::Model => "model",
|
||||
};
|
||||
f.write_str(role_str)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -9,21 +9,16 @@ pub struct CountTokensRequest {
|
||||
|
||||
impl CountTokensRequest {
|
||||
pub fn builder() -> CountTokensRequestBuilder {
|
||||
CountTokensRequestBuilder::new()
|
||||
CountTokensRequestBuilder::default()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CountTokensRequestBuilder {
|
||||
contents: Content,
|
||||
}
|
||||
|
||||
impl CountTokensRequestBuilder {
|
||||
pub fn new() -> Self {
|
||||
CountTokensRequestBuilder {
|
||||
contents: Content::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_prompt(prompt: &str) -> Self {
|
||||
CountTokensRequestBuilder {
|
||||
contents: Content {
|
||||
|
||||
@@ -37,8 +37,8 @@ impl GenerateContentRequestBuilder {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_content(mut self, content: Content) -> Self {
|
||||
self.request.contents.push(content);
|
||||
pub fn contents(mut self, contents: Vec<Content>) -> Self {
|
||||
self.request.contents = contents;
|
||||
self
|
||||
}
|
||||
|
||||
@@ -71,15 +71,39 @@ impl GenerateContentRequestBuilder {
|
||||
pub struct Tools {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function_declarations: Option<Vec<FunctionDeclaration>>,
|
||||
|
||||
#[serde(rename = "googleSearchRetrieval")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub google_search_retrieval: Option<GoogleSearchRetrieval>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub google_search: Option<GoogleSearch>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
pub struct GoogleSearch {}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DynamicRetrievalConfig {
|
||||
pub mode: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dynamic_threshold: Option<f32>,
|
||||
}
|
||||
|
||||
impl Default for DynamicRetrievalConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: "MODE_DYNAMIC".to_string(),
|
||||
dynamic_threshold: Some(0.7),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GoogleSearchRetrieval {
|
||||
pub disable_attribution: bool,
|
||||
pub dynamic_retrieval_config: DynamicRetrievalConfig,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize, Default)]
|
||||
|
||||
@@ -37,6 +37,12 @@ pub struct PredictImageRequestParameters {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub seed: Option<u32>,
|
||||
|
||||
/// Optional. An optional parameter to use an LLM-based prompt rewriting feature to deliver
|
||||
/// higher quality images that better reflect the original prompt's intent. Disabling this
|
||||
/// feature may impact image quality and prompt adherence
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enhance_prompt: Option<bool>,
|
||||
|
||||
/// A description of what to discourage in the generated images.
|
||||
/// The following models support this parameter:
|
||||
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||
@@ -70,6 +76,7 @@ pub struct PredictImageRequestParameters {
|
||||
/// - "watercolor"
|
||||
/// - "cyberpunk"
|
||||
/// - "pop_art"
|
||||
///
|
||||
/// Pre-defined styles is only supported for model imagegeneration@002
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sample_image_style: Option<String>,
|
||||
@@ -84,22 +91,42 @@ pub struct PredictImageRequestParameters {
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub person_generation: Option<String>,
|
||||
pub person_generation: Option<PersonGeneration>,
|
||||
|
||||
/// Optional. The language code that corresponds to your text prompt language.
|
||||
/// The following values are supported:
|
||||
/// - auto: Automatic detection. If Imagen detects a supported language, the prompt and an
|
||||
/// optional negative prompt are translated to English. If the language detected isn't
|
||||
/// supported, Imagen uses the input text verbatim, which might result in an unexpected
|
||||
/// output. No error code is returned.
|
||||
/// - en: English (if omitted, the default value)
|
||||
/// - zh or zh-CN: Chinese (simplified)
|
||||
/// - zh-TW: Chinese (traditional)
|
||||
/// - hi: Hindi
|
||||
/// - ja: Japanese
|
||||
/// - ko: Korean
|
||||
/// - pt: Portuguese
|
||||
/// - es: Spanish
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub language: Option<String>,
|
||||
|
||||
/// Adds a filter level to safety filtering. The following values are supported:
|
||||
/// - `"block_most"`: Strongest filtering level, most strict blocking.
|
||||
/// - `"block_some"`: Block some problematic prompts and responses.
|
||||
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
|
||||
/// increase objectionable content generated by Imagen.
|
||||
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
|
||||
/// feature is restricted.
|
||||
///
|
||||
/// The default value is `"block_some"`.
|
||||
/// - "block_low_and_above": Strongest filtering level, most strict blocking.
|
||||
/// Deprecated value: "block_most".
|
||||
/// - "block_medium_and_above": Block some problematic prompts and responses.
|
||||
/// Deprecated value: "block_some".
|
||||
/// - "block_only_high": Reduces the number of requests blocked due to safety filters. May
|
||||
/// increase objectionable content generated by Imagen. Deprecated value: "block_few".
|
||||
/// - "block_none": Block very few problematic prompts and responses. Access to this feature
|
||||
/// is restricted. Previous field value: "block_fewest".
|
||||
///
|
||||
/// The default value is "block_medium_and_above".
|
||||
///
|
||||
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||
/// `imagegeneration@006` only.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safety_setting: Option<String>,
|
||||
pub safety_setting: Option<PredictImageSafetySetting>,
|
||||
|
||||
/// Add an invisible watermark to the generated images. The default value is `false` for the
|
||||
/// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the
|
||||
@@ -141,3 +168,20 @@ pub struct PredictImageResponsePrediction {
|
||||
pub bytes_base64_encoded: Vec<u8>,
|
||||
pub mime_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PersonGeneration {
|
||||
DontAllow,
|
||||
AllowAdult,
|
||||
AllowAll,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PredictImageSafetySetting {
|
||||
BlockLowAndAbove,
|
||||
BlockMediumAndAbove,
|
||||
BlockOnlyHigh,
|
||||
BlockNone,
|
||||
}
|
||||
|
||||
@@ -51,7 +51,7 @@ pub struct TextEmbeddingPrediction {
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingResult {
|
||||
pub statistics: TextEmbeddingStatistics,
|
||||
pub values: Vec<f32>,
|
||||
pub values: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
|
||||
Reference in New Issue
Block a user