Adds code for image generation
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -7,3 +7,4 @@ Cargo.lock
|
|||||||
.cargo/config.toml
|
.cargo/config.toml
|
||||||
.DS_Store
|
.DS_Store
|
||||||
/.idea
|
/.idea
|
||||||
|
output.jpg
|
||||||
|
|||||||
@@ -13,12 +13,14 @@ reqwest = { version = "0.12", features = ["json", "gzip"] }
|
|||||||
reqwest-eventsource = "0.6"
|
reqwest-eventsource = "0.6"
|
||||||
serde = { version = "1", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = { version = "1"}
|
serde_json = { version = "1"}
|
||||||
|
serde_with = { version = "3.9", features = ["base64"]}
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tokio = { version = "1" }
|
tokio = { version = "1" }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
console = "0.15.8"
|
console = "0.15.8"
|
||||||
dialoguer = "0.11.0"
|
dialoguer = "0.11.0"
|
||||||
|
image = "0.25.2"
|
||||||
indicatif = "0.17.8"
|
indicatif = "0.17.8"
|
||||||
tokio = { version = "1.37.0", features = ["full"] }
|
tokio = { version = "1.37.0", features = ["full"] }
|
||||||
tracing-subscriber = "0.3.18"
|
tracing-subscriber = "0.3.18"
|
||||||
|
|||||||
45
examples/generate_image.rs
Normal file
45
examples/generate_image.rs
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
use std::{error::Error, io::Cursor};
|
||||||
|
|
||||||
|
use gemini_rs::prelude::{
|
||||||
|
GeminiClient, PredictImageRequest, PredictImageRequestParameters, PredictImageRequestPrompt,
|
||||||
|
};
|
||||||
|
use image::{ImageFormat, ImageReader};
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
pub async fn main() -> Result<(), Box<dyn Error>> {
|
||||||
|
tracing_subscriber::fmt().init();
|
||||||
|
let authentication_manager = gcp_auth::provider().await?;
|
||||||
|
let api_endpoint = std::env::var("API_ENDPOINT")?;
|
||||||
|
let project_id = std::env::var("PROJECT_ID")?;
|
||||||
|
let location_id = std::env::var("LOCATION_ID")?;
|
||||||
|
|
||||||
|
let gemini = GeminiClient::new(
|
||||||
|
authentication_manager,
|
||||||
|
api_endpoint,
|
||||||
|
project_id,
|
||||||
|
location_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
let prompt = "
|
||||||
|
Create an image of a tuxedo cat riding a rocket to the moon.";
|
||||||
|
let request = PredictImageRequest {
|
||||||
|
instances: vec![PredictImageRequestPrompt {
|
||||||
|
prompt: prompt.to_string(),
|
||||||
|
}],
|
||||||
|
parameters: PredictImageRequestParameters {
|
||||||
|
sample_count: 1,
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let mut result = gemini
|
||||||
|
.predict_image(&request, "imagen-3.0-fast-generate-001")
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let result = result.predictions.pop().unwrap();
|
||||||
|
|
||||||
|
let format = ImageFormat::from_mime_type(result.mime_type).unwrap();
|
||||||
|
let img =
|
||||||
|
ImageReader::with_format(Cursor::new(result.bytes_base64_encoded), format).decode()?;
|
||||||
|
img.save("output.jpg")?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
@@ -11,6 +11,7 @@ use crate::prelude::{
|
|||||||
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
GenerateContentResponse, GenerateContentResponseResult, GenerationConfig, TextEmbeddingRequest,
|
||||||
TextEmbeddingResponse,
|
TextEmbeddingResponse,
|
||||||
};
|
};
|
||||||
|
use crate::types::{PredictImageRequest, PredictImageResponse};
|
||||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||||
|
|
||||||
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
pub static AUTH_SCOPE: &[&str] = &["https://www.googleapis.com/auth/cloud-platform"];
|
||||||
@@ -252,4 +253,27 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
|||||||
tracing::debug!("count_tokens response: {:?}", txt_json);
|
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||||
Ok(serde_json::from_str(&txt_json)?)
|
Ok(serde_json::from_str(&txt_json)?)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn predict_image(
|
||||||
|
&self,
|
||||||
|
request: &PredictImageRequest,
|
||||||
|
model: &str,
|
||||||
|
) -> Result<PredictImageResponse> {
|
||||||
|
let endpoint_url = format!(
|
||||||
|
"https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
|
||||||
|
self.api_endpoint, self.project_id, self.location_id, model,
|
||||||
|
);
|
||||||
|
|
||||||
|
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(&endpoint_url)
|
||||||
|
.bearer_auth(access_token)
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let txt_json = resp.text().await?;
|
||||||
|
Ok(serde_json::from_str(&txt_json)?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ mod common;
|
|||||||
mod count_tokens;
|
mod count_tokens;
|
||||||
mod error;
|
mod error;
|
||||||
mod generate_content;
|
mod generate_content;
|
||||||
|
mod predict_image;
|
||||||
mod text_embeddings;
|
mod text_embeddings;
|
||||||
|
|
||||||
pub use common::*;
|
pub use common::*;
|
||||||
pub use count_tokens::*;
|
pub use count_tokens::*;
|
||||||
pub use error::*;
|
pub use error::*;
|
||||||
pub use generate_content::*;
|
pub use generate_content::*;
|
||||||
|
pub use predict_image::*;
|
||||||
pub use text_embeddings::*;
|
pub use text_embeddings::*;
|
||||||
|
|||||||
142
src/types/predict_image.rs
Normal file
142
src/types/predict_image.rs
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_with::base64::Base64;
|
||||||
|
use serde_with::serde_as;
|
||||||
|
|
||||||
|
/// Request body for an image prediction call: the prompt instances plus the
/// generation parameters applied to them.
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct PredictImageRequest {
|
||||||
|
/// The prompts to generate images for.
pub instances: Vec<PredictImageRequestPrompt>,
|
||||||
|
/// Generation settings sent alongside the prompts.
pub parameters: PredictImageRequestParameters,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single prompt entry in the `instances` list of an image prediction request.
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct PredictImageRequestPrompt {
|
||||||
|
/// The text prompt for the image.
|
||||||
|
/// The maximum prompt length depends on the model:
|
||||||
|
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||||
|
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||||
|
/// - `imagegeneration@006`: up to 128 tokens.
|
||||||
|
/// - `imagegeneration@005`: up to 128 tokens.
|
||||||
|
/// - `imagegeneration@002`: up to 64 tokens.
|
||||||
|
pub prompt: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Generation settings for an image prediction request.
///
/// Field names are serialized in camelCase. Every `Option` field is omitted from the
/// serialized output when it is `None`.
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PredictImageRequestParameters {
|
||||||
|
/// The number of images to generate. The service's default value is 4.
|
||||||
|
/// The supported range depends on the model:
|
||||||
|
/// - `imagen-3.0-generate-001`: 1 to 4.
|
||||||
|
/// - `imagen-3.0-fast-generate-001`: 1 to 4.
|
||||||
|
/// - `imagegeneration@006`: 1 to 4.
|
||||||
|
/// - `imagegeneration@005`: 1 to 4.
|
||||||
|
/// - `imagegeneration@002`: 1 to 8.
///
/// Note: this field is not optional and is always serialized, and the derived
/// `Default` sets it to `0`, so set it explicitly when using `..Default::default()`.
|
||||||
|
pub sample_count: i32,
|
||||||
|
|
||||||
|
/// The random seed for image generation. This is not available when `addWatermark` is set to
|
||||||
|
/// `true`.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub seed: Option<u32>,
|
||||||
|
|
||||||
|
/// A description of what to discourage in the generated images.
|
||||||
|
/// The maximum length depends on the model:
|
||||||
|
/// - `imagen-3.0-generate-001`: up to 480 tokens.
|
||||||
|
/// - `imagen-3.0-fast-generate-001`: up to 480 tokens.
|
||||||
|
/// - `imagegeneration@006`: up to 128 tokens.
|
||||||
|
/// - `imagegeneration@005`: up to 128 tokens.
|
||||||
|
/// - `imagegeneration@002`: up to 64 tokens.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub negative_prompt: Option<String>,
|
||||||
|
|
||||||
|
/// The aspect ratio for the image. The default value is "1:1".
|
||||||
|
/// The following models support different values for this parameter:
|
||||||
|
/// - `imagen-3.0-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||||
|
/// - `imagen-3.0-fast-generate-001`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||||
|
/// - `imagegeneration@006`: "1:1", "9:16", "16:9", "3:4", or "4:3".
|
||||||
|
/// - `imagegeneration@005`: "1:1" or "9:16".
|
||||||
|
/// - `imagegeneration@002`: "1:1".
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub aspect_ratio: Option<String>,
|
||||||
|
|
||||||
|
/// Describes the output image format in a [`PredictImageRequestParametersOutputOptions`]
|
||||||
|
/// object.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_options: Option<PredictImageRequestParametersOutputOptions>,
|
||||||
|
|
||||||
|
/// Describes the style for the generated images. The following values are supported:
|
||||||
|
/// - "photograph"
|
||||||
|
/// - "digital_art"
|
||||||
|
/// - "landscape"
|
||||||
|
/// - "sketch"
|
||||||
|
/// - "watercolor"
|
||||||
|
/// - "cyberpunk"
|
||||||
|
/// - "pop_art"
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sample_image_style: Option<String>,
|
||||||
|
|
||||||
|
/// Allow generation of people by the model. The following values are supported:
|
||||||
|
/// - `"dont_allow"`: Disallow the inclusion of people or faces in images.
|
||||||
|
/// - `"allow_adult"`: Allow generation of adults only.
|
||||||
|
/// - `"allow_all"`: Allow generation of people of all ages.
|
||||||
|
///
|
||||||
|
/// The default value is `"allow_adult"`.
|
||||||
|
///
|
||||||
|
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||||
|
/// `imagegeneration@006` only.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub person_generation: Option<String>,
|
||||||
|
|
||||||
|
/// Adds a filter level to safety filtering. The following values are supported:
|
||||||
|
/// - `"block_most"`: Strongest filtering level, most strict blocking.
|
||||||
|
/// - `"block_some"`: Block some problematic prompts and responses.
|
||||||
|
/// - `"block_few"`: Reduces the number of requests blocked due to safety filters. May
|
||||||
|
/// increase objectionable content generated by Imagen.
|
||||||
|
/// - `"block_fewest"`: Block very few problematic prompts and responses. Access to this
|
||||||
|
/// feature is restricted.
|
||||||
|
///
|
||||||
|
/// The default value is `"block_some"`.
|
||||||
|
///
|
||||||
|
/// Supported by the models `imagen-3.0-generate-001`, `imagen-3.0-fast-generate-001`, and
|
||||||
|
/// `imagegeneration@006` only.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub safety_setting: Option<String>,
|
||||||
|
|
||||||
|
/// Add an invisible watermark to the generated images. The default value is `false` for the
|
||||||
|
/// `imagegeneration@002` and `imagegeneration@005` models, and `true` for the
|
||||||
|
/// `imagen-3.0-fast-generate-001` and `imagegeneration@006` models.
///
/// NOTE(review): the original text listed `imagegeneration@006` twice; presumably
/// `imagen-3.0-generate-001` was meant for one of them. Confirm against the API docs.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub add_watermark: Option<bool>,
|
||||||
|
|
||||||
|
/// Cloud Storage URI to store the generated images.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub storage_uri: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PredictImageRequestParametersOutputOptions {
|
||||||
|
/// The image format that the output should be saved as. The following values are supported:
|
||||||
|
///
|
||||||
|
/// - "image/png": Save as a PNG image
|
||||||
|
/// - "image/jpeg": Save as a JPEG image
|
||||||
|
///
|
||||||
|
/// The default value is "image/png".v
|
||||||
|
pub mime_type: Option<String>,
|
||||||
|
|
||||||
|
/// The level of compression if the output type is "image/jpeg".
|
||||||
|
/// Accepted values are 0 through 100. The default value is 75.
|
||||||
|
pub compression_quality: Option<i32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Response body of an image prediction call.
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct PredictImageResponse {
|
||||||
|
/// The generated images returned by the service.
pub predictions: Vec<PredictImageResponsePrediction>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single generated image in a [`PredictImageResponse`].
///
/// Field names are (de)serialized in camelCase.
#[serde_as]
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PredictImageResponsePrediction {
|
||||||
|
/// The image data. The wire format is a Base64 string; `serde_as` converts it to raw
/// bytes on deserialization and back to Base64 on serialization.
#[serde_as(as = "Base64")]
|
||||||
|
pub bytes_base64_encoded: Vec<u8>,
|
||||||
|
/// MIME type of the image data.
pub mime_type: String,
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user