|
package gallery |
|
|
|
import ( |
|
"errors" |
|
"fmt" |
|
"os" |
|
"path/filepath" |
|
|
|
"dario.cat/mergo" |
|
lconfig "github.com/mudler/LocalAI/core/config" |
|
"github.com/mudler/LocalAI/pkg/downloader" |
|
"github.com/mudler/LocalAI/pkg/utils" |
|
|
|
"github.com/rs/zerolog/log" |
|
"gopkg.in/yaml.v2" |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type (
	// Config is the YAML description of a single gallery model.
	//
	// Description, Icon, License and URLs are descriptive metadata; Name is
	// the default install name. ConfigFile holds an inline YAML document that
	// InstallModel turns into the model's backend configuration. Files lists
	// the artefacts to download and PromptTemplates the templates to write
	// next to them.
	//
	// Field order matters: yaml.Marshal emits struct fields in declaration
	// order, and InstallModel persists a marshalled copy of this struct.
	Config struct {
		Description     string           `yaml:"description"`
		Icon            string           `yaml:"icon"`
		License         string           `yaml:"license"`
		URLs            []string         `yaml:"urls"`
		Name            string           `yaml:"name"`
		ConfigFile      string           `yaml:"config_file"`
		Files           []File           `yaml:"files"`
		PromptTemplates []PromptTemplate `yaml:"prompt_templates"`
	}

	// File is one downloadable artefact of a gallery model. Filename is the
	// destination path relative to the install base path. SHA256 is the
	// expected digest handed to the downloader, and URI is the download
	// source.
	File struct {
		Filename string `yaml:"filename" json:"filename"`
		SHA256   string `yaml:"sha256" json:"sha256"`
		URI      string `yaml:"uri" json:"uri"`
	}

	// PromptTemplate is a named prompt template. InstallModel writes Content
	// to "<Name>.tmpl" under the install base path.
	PromptTemplate struct {
		Name    string `yaml:"name"`
		Content string `yaml:"content"`
	}
)
|
|
|
func GetGalleryConfigFromURL(url string, basePath string) (Config, error) { |
|
var config Config |
|
uri := downloader.URI(url) |
|
err := uri.DownloadAndUnmarshal(basePath, func(url string, d []byte) error { |
|
return yaml.Unmarshal(d, &config) |
|
}) |
|
if err != nil { |
|
log.Error().Err(err).Str("url", url).Msg("failed to get gallery config for url") |
|
return config, err |
|
} |
|
return config, nil |
|
} |
|
|
|
// ReadConfigFile reads the gallery model definition stored at filePath and
// decodes it from YAML into a Config.
//
// Read and decode errors are wrapped with %w, so callers can inspect the
// underlying cause (for example errors.Is(err, fs.ErrNotExist) when the file
// is missing).
func ReadConfigFile(filePath string) (*Config, error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read YAML file: %w", err)
	}

	var config Config
	if err := yaml.Unmarshal(data, &config); err != nil {
		return nil, fmt.Errorf("failed to unmarshal YAML: %w", err)
	}

	return &config, nil
}
|
|
|
// InstallModel installs the gallery model described by config into basePath.
//
// Steps, in order:
//  1. basePath is created (mode 0750) if missing.
//  2. Each entry of config.Files is downloaded to basePath/<Filename> with
//     downloader.URI.DownloadFile. That call receives the declared SHA256,
//     the file index, the total file count and downloadStatus, which is used
//     for progress reporting. When enforceScan is true, each file URI is
//     first passed to downloader.HuggingFaceScan. The install stops if the
//     scan returns downloader.ErrUnsafeFilesFound; any other scan error is
//     logged and does not block the install.
//  3. Each prompt template is written to basePath/<Name>.tmpl (mode 0600).
//  4. If config.ConfigFile or configOverrides is non-empty, a backend config
//     file is built and written to basePath/<name>.yaml:
//     - config.ConfigFile is parsed as YAML;
//     - its "name" key is set to the install name;
//     - configOverrides is merged on top (overrides win, so they may also
//       replace "name");
//     - the result must pass lconfig.BackendConfig.Validate before it is
//       written.
//  5. config itself is marshalled to YAML and written to
//     basePath/galleryFileName(name).
//
// The install name is nameOverride when non-empty, otherwise config.Name.
// File, template and config-file names are checked with utils.VerifyPath
// against basePath before use, presumably to reject paths that escape
// basePath. A nil config is rejected with an error. Files written before a
// failing step are not cleaned up.
func InstallModel(basePath, nameOverride string, config *Config, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) error {
	if config == nil {
		return errors.New("gallery config is nil")
	}

	// Create base path if it doesn't exist.
	if err := os.MkdirAll(basePath, 0750); err != nil {
		return fmt.Errorf("failed to create base path: %w", err)
	}

	if len(configOverrides) > 0 {
		log.Debug().Msgf("Config overrides %+v", configOverrides)
	}

	// Download files and verify their SHA.
	for i, file := range config.Files {
		log.Debug().Msgf("Checking %q exists and matches SHA", file.Filename)

		if err := utils.VerifyPath(file.Filename, basePath); err != nil {
			return err
		}

		filePath := filepath.Join(basePath, file.Filename)
		uri := downloader.URI(file.URI)

		if enforceScan {
			scanResults, err := downloader.HuggingFaceScan(uri)
			if errors.Is(err, downloader.ErrUnsafeFilesFound) {
				log.Error().Str("model", config.Name).Strs("clamAV", scanResults.ClamAVInfectedFiles).Strs("pickles", scanResults.DangerousPickles).Msg("Contains unsafe file(s)!")
				return err
			}
			if err != nil {
				// Only a positive "unsafe" verdict aborts the install. Other
				// scan errors are non-fatal; log them so they are not lost.
				log.Debug().Err(err).Str("uri", file.URI).Msg("model scan did not complete")
			}
		}

		if err := uri.DownloadFile(filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil {
			return err
		}
	}

	// Write prompt templates.
	for _, template := range config.PromptTemplates {
		templateFile := template.Name + ".tmpl"
		if err := utils.VerifyPath(templateFile, basePath); err != nil {
			return err
		}

		filePath := filepath.Join(basePath, templateFile)

		// template.Name may contain path separators; make sure the parent
		// directory exists before writing.
		if err := os.MkdirAll(filepath.Dir(filePath), 0750); err != nil {
			return fmt.Errorf("failed to create parent directory for prompt template %q: %w", template.Name, err)
		}

		if err := os.WriteFile(filePath, []byte(template.Content), 0600); err != nil {
			return fmt.Errorf("failed to write prompt template %q: %w", template.Name, err)
		}

		log.Debug().Msgf("Prompt template %q written", template.Name)
	}

	name := config.Name
	if nameOverride != "" {
		name = nameOverride
	}

	if err := utils.VerifyPath(name+".yaml", basePath); err != nil {
		return err
	}

	// Write the backend config file only when there is something to write.
	if len(configOverrides) != 0 || len(config.ConfigFile) != 0 {
		configFilePath := filepath.Join(basePath, name+".yaml")

		// Start from the config embedded in the gallery entry.
		configMap := make(map[string]interface{})
		if err := yaml.Unmarshal([]byte(config.ConfigFile), &configMap); err != nil {
			return fmt.Errorf("failed to unmarshal config YAML: %w", err)
		}

		configMap["name"] = name

		// Overrides take precedence over values from the gallery config.
		if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil {
			return fmt.Errorf("failed to merge config overrides: %w", err)
		}

		updatedConfigYAML, err := yaml.Marshal(configMap)
		if err != nil {
			return fmt.Errorf("failed to marshal updated config YAML: %w", err)
		}

		// Decode into a BackendConfig so an invalid result is rejected
		// before anything is written.
		backendConfig := lconfig.BackendConfig{}
		if err := yaml.Unmarshal(updatedConfigYAML, &backendConfig); err != nil {
			return fmt.Errorf("failed to unmarshal updated config YAML: %w", err)
		}
		if !backendConfig.Validate() {
			return errors.New("failed to validate updated config YAML")
		}

		if err := os.WriteFile(configFilePath, updatedConfigYAML, 0600); err != nil {
			return fmt.Errorf("failed to write updated config file: %w", err)
		}

		log.Debug().Msgf("Written config file %s", configFilePath)
	}

	// Persist a copy of the gallery entry itself next to the model.
	modelFile := filepath.Join(basePath, galleryFileName(name))
	data, err := yaml.Marshal(config)
	if err != nil {
		return fmt.Errorf("failed to marshal gallery file: %w", err)
	}

	if err := os.WriteFile(modelFile, data, 0600); err != nil {
		return fmt.Errorf("failed to write gallery file: %w", err)
	}

	// Log only after the write has actually succeeded.
	log.Debug().Msgf("Written gallery file %s", modelFile)
	return nil
}
|
|
|
// galleryFileName returns the dot-prefixed file name, "._gallery_<name>.yaml",
// under which InstallModel stores a copy of the gallery entry for the model
// called name.
func galleryFileName(name string) string {
	const (
		prefix = "._gallery_"
		suffix = ".yaml"
	)
	return prefix + name + suffix
}
|
|